| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119 |
- //===------ SemaDXR.cpp - Semantic Analysis for DXR shader -----*- C++ -*-===//
- ///////////////////////////////////////////////////////////////////////////////
- // //
- // SemaDXR.cpp //
- // Copyright (C) Nvidia Corporation. All rights reserved. //
- // This file is distributed under the University of Illinois Open Source //
- // License. See LICENSE.TXT for details. //
- // //
- // This file defines the semantic support for DXR. //
- // //
- ///////////////////////////////////////////////////////////////////////////////
- #include "clang/AST/ASTContext.h"
- #include "clang/AST/Attr.h"
- #include "clang/AST/Decl.h"
- #include "clang/AST/DeclCXX.h"
- #include "clang/AST/DeclTemplate.h"
- #include "clang/AST/Expr.h"
- #include "clang/AST/ExprCXX.h"
- #include "clang/AST/ExternalASTSource.h"
- #include "clang/AST/RecursiveASTVisitor.h"
- #include "clang/Sema/SemaHLSL.h"
- #include "clang/Analysis/Analyses/Dominators.h"
- #include "clang/Analysis/Analyses/ReachableCode.h"
- #include "clang/Analysis/CFG.h"
- #include "llvm/ADT/BitVector.h"
- #include "dxc/DXIL/DxilConstants.h"
- using namespace clang;
- using namespace sema;
- using namespace hlsl;
- namespace {
- struct PayloadUse {
- PayloadUse() = default;
- PayloadUse(const Stmt *S, const CFGBlock *Parent)
- : S(S), Parent(Parent), Member(nullptr) {}
- PayloadUse(const Stmt *S, const CFGBlock *Parent, const MemberExpr *Member)
- : S(S), Parent(Parent), Member(Member) {}
- bool operator<(const PayloadUse &Other) const { return S < Other.S; }
- const Stmt *S = nullptr;
- const CFGBlock *Parent = nullptr;
- const MemberExpr *Member = nullptr;
- };
- struct TraceRayCall {
- TraceRayCall() = default;
- TraceRayCall(const CallExpr *Call, const CFGBlock *Parent)
- : Call(Call), Parent(Parent) {}
- const CallExpr *Call = nullptr;
- const CFGBlock *Parent = nullptr;
- };
- struct PayloadAccessInfo {
- PayloadAccessInfo() = default;
- PayloadAccessInfo(const MemberExpr *Member, const CallExpr *Call,
- bool IsLValue)
- : Member(Member), Call(Call), IsLValue(IsLValue) {}
- const MemberExpr *Member = nullptr;
- const CallExpr *Call = nullptr;
- bool IsLValue = false;
- };
- struct DxrShaderDiagnoseInfo {
- const FunctionDecl *funcDecl;
- const VarDecl *Payload;
- DXIL::PayloadAccessShaderStage Stage;
- std::vector<TraceRayCall> TraceCalls;
- std::map<const FieldDecl *, std::vector<PayloadUse>> WritesPerField;
- std::map<const FieldDecl *, std::vector<PayloadUse>> ReadsPerField;
- std::vector<PayloadUse> PayloadAsCallArg;
- };
- std::vector<const FieldDecl *>
- DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
- const std::set<const FieldDecl *> &FieldsToIgnoreRead,
- const std::set<const FieldDecl *> &FieldsToIgnoreWrite,
- std::set<const FunctionDecl *> VisitedFunctions);
- const Stmt *IgnoreParensAndDecay(const Stmt *S);
- // Transform the shader stage to string to be used in diagnostics
- StringRef GetStringForShaderStage(DXIL::PayloadAccessShaderStage Stage) {
- StringRef StageNames[] = {"caller", "closesthit", "miss", "anyhit"};
- if (Stage != DXIL::PayloadAccessShaderStage::Invalid)
- return StageNames[static_cast<unsigned>(Stage)];
- return "";
- }
- // Returns the Qualifier for a Field and a given shader stage.
- DXIL::PayloadAccessQualifier
- GetPayloadQualifierForStage(FieldDecl *Field,
- DXIL::PayloadAccessShaderStage Stage) {
- bool hasRead = false;
- bool hasWrite = false;
- for (UnusualAnnotation *annotation : Field->getUnusualAnnotations()) {
- if (auto *payloadAnnotation =
- dyn_cast<hlsl::PayloadAccessAnnotation>(annotation)) {
- for (auto &ShaderStage : payloadAnnotation->ShaderStages) {
- if (ShaderStage != Stage)
- continue;
- hasRead |=
- payloadAnnotation->qualifier == DXIL::PayloadAccessQualifier::Read;
- hasWrite |=
- payloadAnnotation->qualifier == DXIL::PayloadAccessQualifier::Write;
- }
- }
- }
- if (hasRead && hasWrite)
- return DXIL::PayloadAccessQualifier::ReadWrite;
- if (hasRead)
- return DXIL::PayloadAccessQualifier::Read;
- if (hasWrite)
- return DXIL::PayloadAccessQualifier::Write;
- return DXIL::PayloadAccessQualifier::NoAccess;
- }
- // Returns the declaration of the payload used in a TraceRay call
- const VarDecl *GetPayloadParameterForTraceCall(const CallExpr *Trace) {
- const Decl *callee = Trace->getCalleeDecl();
- if (!callee)
- return nullptr;
- if (!isa<FunctionDecl>(callee))
- return nullptr;
- const FunctionDecl *FD = cast<FunctionDecl>(callee);
- if (FD->isImplicit() && FD->getName() == "TraceRay") {
- const Stmt *Param = IgnoreParensAndDecay(Trace->getArg(7));
- if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param)) {
- if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl()))
- return Decl;
- }
- }
- return nullptr;
- }
- // Recursively extracts accesses to a payload struct from a Stmt
- void GetPayloadAccesses(const Stmt *S, const DxrShaderDiagnoseInfo &Info,
- std::vector<PayloadAccessInfo> &Accesses, bool IsLValue,
- const MemberExpr *Member, const CallExpr *Call) {
- for (auto C : S->children()) {
- if (!C)
- continue;
- if (const DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(C)) {
- if (Ref->getDecl() == Info.Payload) {
- Accesses.push_back(PayloadAccessInfo{Member, Call, IsLValue});
- }
- return;
- }
- if (const ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(C)) {
- if (Cast->getCastKind() == CK_LValueToRValue) {
- IsLValue = false;
- }
- }
- GetPayloadAccesses(C, Info, Accesses, IsLValue,
- Member ? Member : dyn_cast<MemberExpr>(C),
- Call ? Call : dyn_cast<CallExpr>(C));
- }
- }
- // Collects all reads, writes and calls with participation of the payload.
- void CollectReadsWritesAndCallsForPayload(const Stmt *S,
- DxrShaderDiagnoseInfo &Info,
- const CFGBlock *Block) {
- std::vector<PayloadAccessInfo> PayloadAccesses;
- GetPayloadAccesses(S, Info, PayloadAccesses, true, dyn_cast<MemberExpr>(S),
- dyn_cast<CallExpr>(S));
- for (auto &Access : PayloadAccesses) {
- // An access to a payload member was found.
- if (Access.Member) {
- FieldDecl *Field = cast<FieldDecl>(Access.Member->getMemberDecl());
- if (Access.IsLValue) {
- Info.WritesPerField[Field].push_back(
- PayloadUse{S, Block, Access.Member});
- } else {
- Info.ReadsPerField[Field].push_back(PayloadUse{S, Block, Access.Member});
- }
- } else if (Access.Call) {
- Info.PayloadAsCallArg.push_back(PayloadUse{S, Block});
- }
- }
- }
- // Collects all TraceRay calls.
- void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,
- const CFGBlock *Block) {
- // TraceRay has void as return type so it should never be something else
- // than a plain CallExpr.
- if (const CallExpr *Call = dyn_cast<CallExpr>(S)) {
- const Decl *Callee = Call->getCalleeDecl();
- if (!Callee || !isa<FunctionDecl>(Callee))
- return;
- const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
- // Ignore trace calls here.
- if (CalledFunction->isImplicit() &&
- CalledFunction->getName() == "TraceRay") {
- Info.TraceCalls.push_back({Call, Block});
- }
- }
- }
- // Find the last write to the payload field in the given block.
- PayloadUse GetLastWriteInBlock(CFGBlock &Block,
- ArrayRef<PayloadUse> PayloadWrites) {
- PayloadUse LastWrite;
- for (auto &Element : Block) { // TODO: reverse iterate?
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- auto It =
- std::find_if(PayloadWrites.begin(), PayloadWrites.end(),
- [&](const PayloadUse &V) { return V.S == S->getStmt(); });
- if (It != std::end(PayloadWrites)) {
- LastWrite = *It;
- LastWrite.Parent = &Block;
- }
- }
- }
- return LastWrite;
- }
- // Travers the CFG until every path has reached a write or the ENTRY.
- void TraverseCFGUntilWrite(CFGBlock &Current, std::vector<PayloadUse> &Writes,
- ArrayRef<PayloadUse> PayloadWrites,
- std::set<const CFGBlock *> &Visited) {
- if (Visited.count(&Current))
- return;
- Visited.insert(&Current);
- for (auto I = Current.pred_begin(), E = Current.pred_end(); I != E; ++I) {
- CFGBlock *Pred = *I;
- if (!Pred)
- continue;
- PayloadUse WriteInPred = GetLastWriteInBlock(*Pred, PayloadWrites);
- if (!WriteInPred.S)
- TraverseCFGUntilWrite(*Pred, Writes, PayloadWrites, Visited);
- else
- Writes.push_back(WriteInPred);
- }
- }
- // Traverse the CFG from the EXIT backwards and stop as soon as a block has a
- // write to the payload field.
- std::vector<PayloadUse>
- GetAllWritesReachingExit(CFG &ShaderCFG, ArrayRef<PayloadUse> PayloadWrites) {
- std::vector<PayloadUse> Writes;
- CFGBlock &Exit = ShaderCFG.getExit();
- std::set<const CFGBlock *> Visited;
- TraverseCFGUntilWrite(Exit, Writes, PayloadWrites, Visited);
- return Writes;
- }
- // Find the first read to the payload field in the given block.
- PayloadUse GetFirstReadInBlock(CFGBlock &Block,
- ArrayRef<PayloadUse> PayloadReads) {
- PayloadUse FirstRead;
- for (auto &Element : Block) {
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- auto It =
- std::find_if(PayloadReads.begin(), PayloadReads.end(),
- [&](const PayloadUse &V) { return V.S == S->getStmt(); });
- if (It != std::end(PayloadReads)) {
- FirstRead = *It;
- FirstRead.Parent = &Block;
- break; // We found the first read and are done with this block.
- }
- }
- }
- return FirstRead;
- }
- // Travers the CFG until every path has reached a read or the EXIT.
- void TraverseCFGUntilRead(CFGBlock &Current, std::vector<PayloadUse> &Reads,
- ArrayRef<PayloadUse> PayloadWrites,
- std::set<const CFGBlock *> &Visited) {
- if (Visited.count(&Current))
- return;
- Visited.insert(&Current);
- for (auto I = Current.succ_begin(), E = Current.succ_end(); I != E; ++I) {
- CFGBlock *Succ = *I;
- if (!Succ)
- continue;
- PayloadUse ReadInSucc = GetFirstReadInBlock(*Succ, PayloadWrites);
- if (!ReadInSucc.S)
- TraverseCFGUntilRead(*Succ, Reads, PayloadWrites, Visited);
- else
- Reads.push_back(ReadInSucc);
- }
- }
- // Traverse the CFG from the ENTRY down and stop as soon as a block has a read
- // to the payload field.
- std::vector<PayloadUse>
- GetAllReadsReachedFromEntry(CFG &ShaderCFG, ArrayRef<PayloadUse> PayloadReads) {
- std::vector<PayloadUse> Reads;
- CFGBlock &Entry = ShaderCFG.getEntry();
- std::set<const CFGBlock *> Visited;
- TraverseCFGUntilRead(Entry, Reads, PayloadReads, Visited);
- return Reads;
- }
- // Returns the record type of a payload declaration.
- CXXRecordDecl *GetPayloadType(const VarDecl *Payload) {
- auto PayloadType = Payload->getType();
- if (PayloadType->isStructureOrClassType()) {
- return PayloadType->getAsCXXRecordDecl();
- }
- return nullptr;
- }
- std::vector<FieldDecl*> GetAllPayloadFields(RecordDecl* PayloadType) {
- std::vector<FieldDecl*> PayloadFields;
- for (FieldDecl *Field : PayloadType->fields()) {
- QualType FieldType = Field->getType();
- if (RecordDecl *FieldRecordDecl = FieldType->getAsCXXRecordDecl()) {
- // Skip nested payload types.
- if (FieldRecordDecl->hasAttr<HLSLRayPayloadAttr>()) {
- auto SubTypeFields = GetAllPayloadFields(FieldRecordDecl);
- PayloadFields.insert(PayloadFields.end(), SubTypeFields.begin(), SubTypeFields.end());
- continue;
- }
- }
- PayloadFields.push_back(Field);
- }
- return PayloadFields;
- }
- // Returns true if the field is writeable in an earlier shader stage.
- bool IsFieldWriteableInEarlierStage(FieldDecl *Field,
- DXIL::PayloadAccessShaderStage ThisStage) {
- bool isWriteableInEarlierStage = false;
- switch (ThisStage) {
- case DXIL::PayloadAccessShaderStage::Anyhit:
- case DXIL::PayloadAccessShaderStage::Closesthit:
- case DXIL::PayloadAccessShaderStage::Miss: {
- auto Qualifier = GetPayloadQualifierForStage(
- Field, DXIL::PayloadAccessShaderStage::Caller);
- isWriteableInEarlierStage =
- Qualifier == DXIL::PayloadAccessQualifier::Write ||
- Qualifier == DXIL::PayloadAccessQualifier::ReadWrite;
- Qualifier = GetPayloadQualifierForStage(
- Field, DXIL::PayloadAccessShaderStage::Anyhit);
- isWriteableInEarlierStage |=
- Qualifier == DXIL::PayloadAccessQualifier::Write ||
- Qualifier == DXIL::PayloadAccessQualifier::ReadWrite;
- } break;
- default:
- break;
- }
- return isWriteableInEarlierStage;
- }
- // Emit warnings on payload writes.
- void DiagnosePayloadWrites(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
- const DxrShaderDiagnoseInfo &Info,
- ArrayRef<FieldDecl *> NonWriteableFields,
- RecordDecl *PayloadType) {
- for (FieldDecl *Field : NonWriteableFields) {
- auto WritesToField = Info.WritesPerField.find(Field);
- if (WritesToField == Info.WritesPerField.end())
- continue;
- const auto &WritesToDiagnose =
- GetAllWritesReachingExit(ShaderCFG, WritesToField->second);
- for (auto &Write : WritesToDiagnose) {
- FieldDecl *MemField = cast<FieldDecl>(Write.Member->getMemberDecl());
- auto Qualifier = GetPayloadQualifierForStage(MemField, Info.Stage);
- if (Qualifier != DXIL::PayloadAccessQualifier::Write &&
- Qualifier != DXIL::PayloadAccessQualifier::ReadWrite) {
- S.Diag(Write.Member->getExprLoc(), diag::warn_hlsl_payload_access_write_loss)
- << Field->getName() << GetStringForShaderStage(Info.Stage);
- }
- }
- }
- // Check if a field is not unconditionally written and a write form an earlier
- // stage will be lost.
- auto PayloadFields = GetAllPayloadFields(PayloadType);
- for (FieldDecl *Field : PayloadFields) {
- auto Qualifier = GetPayloadQualifierForStage(Field, Info.Stage);
- if (IsFieldWriteableInEarlierStage(Field, Info.Stage) &&
- Qualifier == DXIL::PayloadAccessQualifier::Write) {
- // The field is writeable in an earlier stage and pure write in this
- // stage. Check if we find a write that dominates the exit of the
- // function.
- bool fieldHasDominatingWrite = false;
- auto It = Info.WritesPerField.find(Field);
- if (It != Info.WritesPerField.end()) {
- for (auto &Write : It->second) {
- fieldHasDominatingWrite =
- DT.dominates(Write.Parent, &ShaderCFG.getExit());
- if (fieldHasDominatingWrite)
- break;
- }
- }
- if (!fieldHasDominatingWrite) {
- S.Diag(Info.Payload->getLocation(),
- diag::warn_hlsl_payload_access_data_loss)
- << Field->getName() << GetStringForShaderStage(Info.Stage);
- }
- }
- }
- }
- // Returns true if A is earlier than B in Parent
- bool IsEarlierStatementAs(const Stmt *A, const Stmt *B,
- const CFGBlock &Parent) {
- for (auto Element : Parent) {
- if (auto S = Element.getAs<CFGStmt>()) {
- if (S->getStmt() == A)
- return true;
- if (S->getStmt() == B)
- return false;
- }
- }
- return true;
- }
- // Returns true if the write dominates payload use.
- template <typename T>
- bool WriteDominatesUse(const PayloadUse &Write, const T &Use,
- DominatorTree &DT) {
- if (Use.Parent == Write.Parent) {
- // Use and write are in the same Block.
- return IsEarlierStatementAs(Write.S, Use.S, *Use.Parent);
- }
- return DT.dominates(Write.Parent, Use.Parent);
- }
- // Emit warnings for payload reads.
- void DiagnosePayloadReads(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
- const DxrShaderDiagnoseInfo &Info,
- ArrayRef<FieldDecl *> NonReadableFields) {
- for (FieldDecl *Field : NonReadableFields) {
- auto ReadsFromField = Info.ReadsPerField.find(Field);
- if (ReadsFromField == Info.ReadsPerField.end())
- continue;
- auto WritesToField = Info.WritesPerField.find(Field);
- bool FieldHasWrites = WritesToField != Info.WritesPerField.end();
- const auto &ReadsToDiagnose =
- GetAllReadsReachedFromEntry(ShaderCFG, ReadsFromField->second);
- for (auto &Read : ReadsToDiagnose) {
- bool ReadIsDominatedByWrite = false;
- if (FieldHasWrites) {
- // We found a read to a field that needs diagnose.
- // We do not want to warn about fields that read but are dominated by a
- // write. Find writes that dominate the read. If we found one, ignore
- // the read.
- for (auto Write : WritesToField->second) {
- ReadIsDominatedByWrite = WriteDominatesUse(Write, Read, DT);
- if (ReadIsDominatedByWrite)
- break;
- }
- }
- if (ReadIsDominatedByWrite)
- continue;
- FieldDecl *MemField = cast<FieldDecl>(Read.Member->getMemberDecl());
- auto Qualifier = GetPayloadQualifierForStage(MemField, Info.Stage);
- if (Qualifier != DXIL::PayloadAccessQualifier::Read &&
- Qualifier != DXIL::PayloadAccessQualifier::ReadWrite) {
- S.Diag(Read.Member->getExprLoc(), diag::warn_hlsl_payload_access_undef_read)
- << Field->getName() << GetStringForShaderStage(Info.Stage);
- }
- }
- }
- }
- // Generic CFG traversal that performs PerElementAction on every Stmt in the
- // CFG.
- template <bool Backward, typename Action>
- void TraverseCFG(const CFGBlock &Block, Action PerElementAction,
- std::set<const CFGBlock *> &Visited) {
- if (Visited.count(&Block))
- return;
- Visited.insert(&Block);
- for (const auto &Element : Block) {
- PerElementAction(Block, Element);
- }
- if (!Backward) {
- for (auto I = Block.succ_begin(), E = Block.succ_end(); I != E; ++I) {
- CFGBlock *Succ = *I;
- if (!Succ)
- continue;
- TraverseCFG</*Backward=*/false>(*Succ, PerElementAction, Visited);
- }
- } else {
- for (auto I = Block.pred_begin(), E = Block.pred_end(); I != E; ++I) {
- CFGBlock *Pred = *I;
- if (!Pred)
- continue;
- TraverseCFG<Backward>(*Pred, PerElementAction, Visited);
- }
- }
- }
- // Forward traverse the CFG and collect calls to TraceRay.
- void ForwardTraverseCFGAndCollectTraceCalls(
- const CFGBlock &Block, DxrShaderDiagnoseInfo &Info,
- std::set<const CFGBlock *> &Visited) {
- auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- CollectTraceRayCalls(S->getStmt(), Info, &Block);
- }
- };
- TraverseCFG<false>(Block, Action, Visited);
- }
- // Foward traverse the CFG and collect all reads and writes to the payload.
- void ForwardTraverseCFGAndCollectReadsWrites(
- const CFGBlock &StartBlock, DxrShaderDiagnoseInfo &Info,
- std::set<const CFGBlock *> &Visited) {
- auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- CollectReadsWritesAndCallsForPayload(S->getStmt(), Info, &Block);
- }
- };
- TraverseCFG<false>(StartBlock, Action, Visited);
- }
- // Backward traverse the CFG and collect all reads and writes to the payload.
- void BackwardTraverseCFGAndCollectReadsWrites(
- const CFGBlock &StartBlock, DxrShaderDiagnoseInfo &Info,
- std::set<const CFGBlock *> &Visited) {
- auto Action = [&](const CFGBlock &Block, const CFGElement &Element) {
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- CollectReadsWritesAndCallsForPayload(S->getStmt(), Info, &Block);
- }
- };
- TraverseCFG<true>(StartBlock, Action, Visited);
- }
- // Returns true if the Stmt uses the Payload.
- bool IsPayloadArg(const Stmt *S, const Decl *Payload) {
- if (const DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(S)) {
- const Decl *Decl = Ref->getDecl();
- if (Decl == Payload)
- return true;
- }
- for (auto C : S->children()) {
- if (IsPayloadArg(C, Payload))
- return true;
- }
- return false;
- }
- bool DiagnoseCallExprForExternal(Sema &S, const FunctionDecl *FD,
- const CallExpr *CE,
- const ParmVarDecl *Payload);
- // Collects all writes that dominate a PayloadUse in a CallExpr
- // and returns a set of the Fields accessed.
- std::set<const FieldDecl *>
- CollectDominatingWritesForCall(PayloadUse &Use, DxrShaderDiagnoseInfo &Info,
- DominatorTree &DT) {
- std::set<const FieldDecl *> FieldsToIgnore;
- for (auto P : Info.WritesPerField) {
- for (auto Write : P.second) {
- bool WriteDominatesCallSite = WriteDominatesUse(Write, Use, DT);
- if (WriteDominatesCallSite) {
- FieldsToIgnore.insert(P.first);
- break;
- }
- }
- }
- return FieldsToIgnore;
- }
- // Collects all reads that are reachable from a PayloadUse in a CallExpr
- // and returns a set of the Fields accessed.
- std::set<const FieldDecl *>
- CollectReachableWritesForCall(PayloadUse &Use,
- const DxrShaderDiagnoseInfo &Info) {
- std::set<const FieldDecl *> FieldsToIgnore;
- assert(Use.Parent);
- const CFGBlock *Current = Use.Parent;
- // Traverse the CFG beginning from the block of the call and collect all
- // fields written to after the call. These fields must not be diagnosed with
- // warnings about lost writes.
- DxrShaderDiagnoseInfo TempInfo;
- TempInfo.Payload = Info.Payload;
- bool foundCall = false;
- for (auto &Element : *Current) {
- // Search for the Call in the block
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- if (S->getStmt() == Use.S) {
- foundCall = true;
- continue;
- }
- if (foundCall)
- CollectReadsWritesAndCallsForPayload(S->getStmt(), TempInfo, Current);
- }
- }
- for (auto I = Current->succ_begin(); I != Current->succ_end(); ++I) {
- CFGBlock *Succ = *I;
- if (!Succ)
- continue;
- std::set<const CFGBlock *> Visited;
- ForwardTraverseCFGAndCollectReadsWrites(*Succ, TempInfo, Visited);
- }
- for (auto &p : TempInfo.WritesPerField)
- FieldsToIgnore.insert(p.first);
- return FieldsToIgnore;
- }
- // Emit diagnostics when the payload is used as an argument
- // in a function call.
- std::map<PayloadUse, std::vector<const FieldDecl *>>
- DiagnosePayloadAsFunctionArg(
- Sema &S, DxrShaderDiagnoseInfo &Info, DominatorTree &DT,
- const std::set<const FieldDecl *> &ParentFieldsToIgnoreRead,
- const std::set<const FieldDecl *> &ParentFieldsToIgnoreWrite,
- std::set<const FunctionDecl *> VisitedFunctions) {
- std::map<PayloadUse, std::vector<const FieldDecl *>> WrittenFieldsInCalls;
- for (PayloadUse &Use : Info.PayloadAsCallArg) {
- if (const CallExpr *Call = dyn_cast<CallExpr>(Use.S)) {
- const Decl *Callee = Call->getCalleeDecl();
- if (!Callee || !isa<FunctionDecl>(Callee))
- continue;
- const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
- // Ignore trace calls here.
- if (CalledFunction->isImplicit() &&
- CalledFunction->getName() == "TraceRay") {
- Info.TraceCalls.push_back(TraceRayCall{Call, Use.Parent});
- continue;
- }
- // Handle external function calls
- if (!CalledFunction->hasBody()) {
- assert(isa<ParmVarDecl>(Info.Payload));
- DiagnoseCallExprForExternal(S, CalledFunction, Call,
- cast<ParmVarDecl>(Info.Payload));
- continue;
- }
- if (VisitedFunctions.count(CalledFunction))
- return WrittenFieldsInCalls;
- VisitedFunctions.insert(CalledFunction);
- DxrShaderDiagnoseInfo CalleeInfo;
- for (unsigned i = 0; i < Call->getNumArgs(); ++i) {
- const Expr *Arg = Call->getArg(i);
- if (IsPayloadArg(Arg, Info.Payload)) {
- CalleeInfo.Payload = CalledFunction->getParamDecl(i);
- break;
- }
- }
- if (CalleeInfo.Payload) {
- CalleeInfo.funcDecl = CalledFunction;
- CalleeInfo.Stage = Info.Stage;
- auto FieldsToIgnoreRead = CollectDominatingWritesForCall(Use, Info, DT);
- auto FieldsToIgnoreWrite = CollectReachableWritesForCall(Use, Info);
- FieldsToIgnoreRead.insert(ParentFieldsToIgnoreRead.begin(),
- ParentFieldsToIgnoreRead.end());
- FieldsToIgnoreWrite.insert(ParentFieldsToIgnoreWrite.begin(),
- ParentFieldsToIgnoreWrite.end());
- WrittenFieldsInCalls[Use] =
- DiagnosePayloadAccess(S, CalleeInfo, FieldsToIgnoreRead,
- FieldsToIgnoreWrite, VisitedFunctions);
- }
- }
- }
- return WrittenFieldsInCalls;
- }
- // Collect all fields that cannot be accessed for the given shader stage.
- // This function recurses into nested payload types.
- void CollectNonAccessableFields(
- RecordDecl *PayloadType, DXIL::PayloadAccessShaderStage Stage,
- const std::set<const FieldDecl *> &FieldsToIgnoreRead,
- const std::set<const FieldDecl *> &FieldsToIgnoreWrite,
- std::vector<FieldDecl *> &NonWriteableFields,
- std::vector<FieldDecl *> &NonReadableFields) {
- for (FieldDecl *Field : PayloadType->fields()) {
- QualType FieldType = Field->getType();
- if (RecordDecl *FieldRecordDecl = FieldType->getAsCXXRecordDecl()) {
- if (FieldRecordDecl->hasAttr<HLSLRayPayloadAttr>()) {
- CollectNonAccessableFields(FieldRecordDecl, Stage, FieldsToIgnoreRead,
- FieldsToIgnoreWrite, NonWriteableFields,
- NonReadableFields);
- continue;
- }
- }
- auto Qualifier = GetPayloadQualifierForStage(Field, Stage);
- // Diagnose writes only if they are not written heigher in the call-graph.
- if (!FieldsToIgnoreWrite.count(Field)) {
- if (Qualifier != DXIL::PayloadAccessQualifier::Write &&
- Qualifier != DXIL::PayloadAccessQualifier::ReadWrite)
- NonWriteableFields.push_back(Field);
- }
- // Diagnose reads only if they have no write heigher in the call-graph.
- if (!FieldsToIgnoreRead.count(Field)) {
- if (Qualifier != DXIL::PayloadAccessQualifier::Read &&
- Qualifier != DXIL::PayloadAccessQualifier::ReadWrite)
- NonReadableFields.push_back(Field);
- }
- }
- }
- void CollectAccessableFields(RecordDecl *PayloadType,
- const std::vector<FieldDecl *> &NonWriteableFields,
- const std::vector<FieldDecl *> &NonReadableFields,
- std::vector<FieldDecl *> &WriteableFields,
- std::vector<FieldDecl *> &ReadableFields) {
- for (FieldDecl *Field : PayloadType->fields()) {
- QualType FieldType = Field->getType();
- if (RecordDecl *FieldRecordDecl = FieldType->getAsCXXRecordDecl()) {
- // Skip nested payload types.
- if (FieldRecordDecl->hasAttr<HLSLRayPayloadAttr>()) {
- CollectAccessableFields(FieldRecordDecl, NonWriteableFields,
- NonReadableFields, WriteableFields,
- ReadableFields);
- continue;
- }
- }
- if (std::find(NonWriteableFields.begin(), NonWriteableFields.end(),
- Field) == NonWriteableFields.end())
- WriteableFields.push_back(Field);
- if (std::find(NonReadableFields.begin(), NonReadableFields.end(), Field) ==
- NonReadableFields.end())
- ReadableFields.push_back(Field);
- }
- }
- // Emit diagnostics for a TraceRay call.
- void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
- const TraceRayCall &Trace, DominatorTree &DT) {
- // For each TraceRay call check if write(caller) fields are written.
- const DXIL::PayloadAccessShaderStage CallerStage =
- DXIL::PayloadAccessShaderStage::Caller;
- std::vector<FieldDecl *> WriteableFields;
- std::vector<FieldDecl *> NonWriteableFields;
- std::vector<FieldDecl *> ReadableFields;
- std::vector<FieldDecl *> NonReadableFields;
- RecordDecl *PayloadType = GetPayloadType(Payload);
- // Check if the payload type used for this trace call is a payload type
- if (!PayloadType->hasAttr<HLSLRayPayloadAttr>()) {
- S.Diag(Payload->getLocation(), diag::err_payload_requires_attribute)
- << PayloadType->getName();
- return;
- }
- CollectNonAccessableFields(PayloadType, CallerStage, {}, {},
- NonWriteableFields, NonReadableFields);
- CollectAccessableFields(PayloadType, NonWriteableFields, NonReadableFields,
- WriteableFields, ReadableFields);
- // Find all writes to Payload that reaches the Trace
- DxrShaderDiagnoseInfo TraceInfo;
- TraceInfo.Payload = Payload;
- std::set<const CFGBlock *> Visited;
- const CFGBlock *Parent = Trace.Parent;
- Visited.insert(Parent);
- // Collect payload accesses in the same block until we reach the TraceRay call
- for (auto Element : *Parent) {
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- if (S->getStmt() == Trace.Call)
- break;
- CollectReadsWritesAndCallsForPayload(S->getStmt(), TraceInfo, Parent);
- }
- }
- for (auto I = Parent->pred_begin(); I != Parent->pred_end(); ++I) {
- CFGBlock *Pred = *I;
- if (!Pred)
- continue;
- BackwardTraverseCFGAndCollectReadsWrites(*Pred, TraceInfo, Visited);
- }
- // Warn if a writeable field has not been written.
- for (const FieldDecl *Field : WriteableFields) {
- if (!TraceInfo.WritesPerField.count(Field)) {
- S.Diag(Trace.Call->getArg(7)->getExprLoc(),
- diag::warn_hlsl_payload_access_no_write_for_trace_payload)
- << Field->getName();
- }
- }
- // Warn if a written field is not write(caller)
- for (const FieldDecl *Field : NonWriteableFields) {
- if (TraceInfo.WritesPerField.count(Field)) {
- S.Diag(
- Trace.Call->getArg(7)->getExprLoc(),
- diag::warn_hlsl_payload_access_write_but_no_write_for_trace_payload)
- << Field->getName();
- }
- }
- // After a trace call, collect all reads that are not dominated by another
- // write warn if a field is not read(caller) but the value is read (undef
- // read).
- // Discard reads/writes from backward traversal.
- TraceInfo.ReadsPerField.clear();
- TraceInfo.WritesPerField.clear();
- bool CallFound = false;
- for (auto Element : *Parent) { // TODO: reverse iterate?
- if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
- if (S->getStmt() == Trace.Call) {
- CallFound = true;
- continue;
- }
- if (CallFound)
- CollectReadsWritesAndCallsForPayload(S->getStmt(), TraceInfo, Parent);
- }
- }
- for (auto I = Parent->succ_begin(); I != Parent->succ_end(); ++I) {
- CFGBlock *Pred = *I;
- if (!Pred)
- continue;
- ForwardTraverseCFGAndCollectReadsWrites(*Pred, TraceInfo, Visited);
- }
- for (const FieldDecl *Field : ReadableFields) {
- if (!TraceInfo.ReadsPerField.count(Field)) {
- S.Diag(Trace.Call->getArg(7)->getExprLoc(),
- diag::warn_hlsl_payload_access_read_but_no_read_after_trace)
- << Field->getName();
- }
- }
- for (const FieldDecl *Field : NonReadableFields) {
- auto WritesToField = TraceInfo.WritesPerField.find(Field);
- bool FieldHasWrites = WritesToField != TraceInfo.WritesPerField.end();
- for (auto &Read : TraceInfo.ReadsPerField[Field]) {
- bool ReadIsDominatedByWrite = false;
- if (FieldHasWrites) {
- // We found a read to a field that needs diagnose.
- // We do not want to warn about fields that read but are dominated by
- // a write. Find writes that dominate the read. If we found one,
- // ignore the read.
- for (auto Write : WritesToField->second) {
- ReadIsDominatedByWrite = WriteDominatesUse(Write, Read, DT);
- if (ReadIsDominatedByWrite)
- break;
- }
- }
- if (ReadIsDominatedByWrite)
- continue;
- S.Diag(Read.Member->getExprLoc(),
- diag::warn_hlsl_payload_access_read_of_undef_after_trace)
- << Field->getName();
- }
- }
- }
- // Emit diagnostics for all TraceRay calls.
- void DiagnoseTraceCalls(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
- DxrShaderDiagnoseInfo &Info) {
- // Collect TraceRay calls in the shader.
- std::set<const CFGBlock *> Visited;
- ForwardTraverseCFGAndCollectTraceCalls(ShaderCFG.getEntry(), Info, Visited);
- std::set<const CallExpr *> Diagnosed;
- for (const TraceRayCall &TraceCall : Info.TraceCalls) {
- if (Diagnosed.count(TraceCall.Call))
- continue;
- Diagnosed.insert(TraceCall.Call);
- const VarDecl *Payload = GetPayloadParameterForTraceCall(TraceCall.Call);
- DiagnoseTraceCall(S, Payload, TraceCall, DT);
- }
- }
- // Emit diagnostics for all access to the payload of a shader,
- // and the input to TraceRay calls.
- std::vector<const FieldDecl *>
- DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
- const std::set<const FieldDecl *> &FieldsToIgnoreRead,
- const std::set<const FieldDecl *> &FieldsToIgnoreWrite,
- std::set<const FunctionDecl *> VisitedFunctions) {
- clang::DominatorTree DT;
- AnalysisDeclContextManager AnalysisManager;
- AnalysisDeclContext *AnalysisContext =
- AnalysisManager.getContext(Info.funcDecl);
- CFG &TheCFG = *AnalysisContext->getCFG();
- DT.buildDominatorTree(*AnalysisContext);
- // Collect all Fields that gets written to return it back up through the
- // recursion.
- std::vector<const FieldDecl *> WrittenFields;
- // Skip if we are in a RayGeneration shader without payload.
- if (Info.Payload) {
- std::vector<FieldDecl *> NonWriteableFields;
- std::vector<FieldDecl *> NonReadableFields;
- RecordDecl *PayloadType = GetPayloadType(Info.Payload);
- if (!PayloadType)
- return WrittenFields;
- CollectNonAccessableFields(PayloadType, Info.Stage, FieldsToIgnoreRead,
- FieldsToIgnoreWrite, NonWriteableFields,
- NonReadableFields);
- std::set<const CFGBlock *> Visited;
- ForwardTraverseCFGAndCollectReadsWrites(TheCFG.getEntry(), Info, Visited);
- if (Info.Payload->hasAttr<HLSLOutAttr>() ||
- Info.Payload->hasAttr<HLSLInOutAttr>()) {
- // If there is copy-out semantic on the payload field,
- // save the written fields and return it back to the caller for
- // better diagnostics in higher recursion levels.
- for (auto &p : Info.WritesPerField) {
- WrittenFields.push_back(p.first);
- }
- DiagnosePayloadWrites(S, TheCFG, DT, Info, NonWriteableFields,
- PayloadType);
- }
- auto WrittenFieldsInCalls = DiagnosePayloadAsFunctionArg(
- S, Info, DT, FieldsToIgnoreRead, FieldsToIgnoreWrite, VisitedFunctions);
- // Add calls that write fields as writes to allow the diagnostics on reads
- // to check if a call that writes the field dominates the read.
- for (auto& P : WrittenFieldsInCalls) {
- for (const FieldDecl* Field : P.second) {
- Info.WritesPerField[Field].push_back(P.first);
- }
- }
- if (Info.Payload->hasAttr<HLSLInAttr>() ||
- Info.Payload->hasAttr<HLSLInOutAttr>())
- DiagnosePayloadReads(S, TheCFG, DT, Info, NonReadableFields);
- }
- DiagnoseTraceCalls(S, TheCFG, DT, Info);
- return WrittenFields;
- }
- const Stmt *IgnoreParensAndDecay(const Stmt *S) {
- for (;;) {
- switch (S->getStmtClass()) {
- case Stmt::ParenExprClass:
- S = cast<ParenExpr>(S)->getSubExpr();
- break;
- case Stmt::ImplicitCastExprClass: {
- const ImplicitCastExpr *castExpr = cast<ImplicitCastExpr>(S);
- if (castExpr->getCastKind() != CK_ArrayToPointerDecay &&
- castExpr->getCastKind() != CK_NoOp &&
- castExpr->getCastKind() != CK_LValueToRValue) {
- return S;
- }
- S = castExpr->getSubExpr();
- } break;
- default:
- return S;
- }
- }
- }
- // Emit warnings for calls that pass the payload to extern functions.
- bool DiagnoseCallExprForExternal(Sema &S, const FunctionDecl *FD,
- const CallExpr *CE,
- const ParmVarDecl *Payload) {
- // We check if we are passing the entire payload struct to an extern function.
- // Here ends what we can check, so we just issue a warning.
- if (!FD->hasBody()) {
- const DeclRefExpr *DRef = nullptr;
- const ParmVarDecl *PDecl = nullptr;
- for (unsigned i = 0; i < CE->getNumArgs(); ++i) {
- const Stmt *arg = IgnoreParensAndDecay(CE->getArg(i));
- if (const DeclRefExpr *ArgRef = dyn_cast<DeclRefExpr>(arg)) {
- if (ArgRef->getDecl() == Payload) {
- DRef = ArgRef;
- PDecl = FD->getParamDecl(i);
- break;
- }
- }
- }
- if (DRef) {
- S.Diag(CE->getExprLoc(),
- diag::warn_qualified_payload_passed_to_extern_function);
- return true;
- }
- }
- return false;
- }
- // Emits diagnostics for the Payload parameter of a DXR shader stage.
- bool DiagnosePayloadParameter(Sema &S, ParmVarDecl *Payload, FunctionDecl *FD,
- DXIL::PayloadAccessShaderStage stage) {
- if (!Payload) {
- // cought already during codgegen of the function
- return false;
- }
- if (!Payload->getAttr<HLSLInOutAttr>()) {
- // error: payload must be inout qualified
- return false;
- }
- CXXRecordDecl *Decl = Payload->getType()->getAsCXXRecordDecl();
- if (!Decl || Decl->isImplicit()) {
- // error: not a user defined type decl
- return false;
- }
- if (!Decl->hasAttr<HLSLRayPayloadAttr>()) {
- S.Diag(Payload->getLocation(), diag::err_payload_requires_attribute)
- << Decl->getName();
- return false;
- }
- return true;
- }
- class DXRShaderVisitor : public RecursiveASTVisitor<DXRShaderVisitor> {
- public:
- DXRShaderVisitor(Sema &S) : S(S) {}
- void diagnose(TranslationUnitDecl *TU) { TraverseTranslationUnitDecl(TU); }
- bool VisitFunctionDecl(FunctionDecl *Decl) {
- auto attr = Decl->getAttr<HLSLShaderAttr>();
- if (!attr)
- return true;
- StringRef shaderStage = attr->getStage();
- if (StringRef("miss,closesthit,anyhit,raygeneration").count(shaderStage)) {
- ParmVarDecl *Payload = nullptr;
- if (shaderStage != "raygeneration")
- Payload = Decl->getParamDecl(0);
- DXIL::PayloadAccessShaderStage Stage =
- DXIL::PayloadAccessShaderStage::Invalid;
- if (shaderStage == "closesthit") {
- Stage = DXIL::PayloadAccessShaderStage::Closesthit;
- } else if (shaderStage == "miss") {
- Stage = DXIL::PayloadAccessShaderStage::Miss;
- } else if (shaderStage == "anyhit") {
- Stage = DXIL::PayloadAccessShaderStage::Anyhit;
- }
- // Diagnose the payload parameter.
- if (Payload) {
- DiagnosePayloadParameter(S, Payload, Decl, Stage);
- }
- DxrShaderDiagnoseInfo Info;
- Info.funcDecl = Decl;
- Info.Payload = Payload;
- Info.Stage = Stage;
- std::set<const FunctionDecl *> VisitedFunctions;
- DiagnosePayloadAccess(S, Info, {}, {}, VisitedFunctions);
- }
- return true;
- }
- private:
- Sema &S;
- };
- } // namespace
- namespace hlsl {
- void DiagnoseRaytracingPayloadAccess(clang::Sema &S,
- clang::TranslationUnitDecl *TU) {
- DXRShaderVisitor visitor(S);
- visitor.diagnose(TU);
- }
- } // namespace hlsl
|