浏览代码

Merge pull request #85 from o3de/auto-option-ranks

Auto option ranks (improvement)
siliconvoodoo 2 年之前
父节点
当前提交
ae0bd5cf47

+ 1 - 0
src/AzslcIntermediateRepresentation.cpp

@@ -121,6 +121,7 @@ namespace AZ::ShaderCompiler
         // We're not going to transform our flat vector to tree and to vector again,
         // We're not going to transform our flat vector to tree and to vector again,
         // instead we'll shift around the elements to respect that order.
         // instead we'll shift around the elements to respect that order.
         // And to do so, we'll use a dependency DAG and a topological solver.
         // And to do so, we'll use a dependency DAG and a topological solver.
+        verboseCout << "--Middle End--\n";
 
 
         m_symbols.ReorderBySymbolDependency();
         m_symbols.ReorderBySymbolDependency();
 
 

+ 6 - 15
src/AzslcMain.cpp

@@ -269,12 +269,12 @@ namespace AZ::ShaderCompiler::Main
             visitOptions);
             visitOptions);
     }
     }
 
 
-    void ParseWarningLevel(const unordered_map<Warn::EnumType, bool>& args, DiagnosticStream& warningConfig)
+    void ParseWarningLevel(const std::array<bool, Warn::EndEnumeratorSentinel_>& args, DiagnosticStream& warningConfig)
     {
     {
         for (auto level : Warn::Enumerate{})
         for (auto level : Warn::Enumerate{})
         {
         {
-            auto lookup = args.find(level);
-            if (lookup != args.end() && lookup->second)
+            bool active = args[level];
+            if (active)
             {
             {
                 if (level >= Warn::Wx)
                 if (level >= Warn::Wx)
                 {
                 {
@@ -423,7 +423,7 @@ int main(int argc, const char* argv[])
     int maxSpaces = std::numeric_limits<int>::max();
     int maxSpaces = std::numeric_limits<int>::max();
     auto maxSpacesOpt = cli.add_option("--max-spaces", maxSpaces, "Will choose register spaces that do not extend past this limit.");
     auto maxSpacesOpt = cli.add_option("--max-spaces", maxSpaces, "Will choose register spaces that do not extend past this limit.");
 
 
-    std::unordered_map<Warn::EnumType, bool> warningOpts;
+    std::array<bool, Warn::EndEnumeratorSentinel_> warningOpts;
     for (const auto e : Warn::Enumerate{})
     for (const auto e : Warn::Enumerate{})
     {
     {
         warningOpts[e] = false;
         warningOpts[e] = false;
@@ -650,8 +650,6 @@ int main(int argc, const char* argv[])
             // intermediate state validation
             // intermediate state validation
             ir.Validate();
             ir.Validate();
 
 
-            bool doEmission = true;
-
             if (stripUnusedSrgs)
             if (stripUnusedSrgs)
             {
             {
                 ir.RemoveUnusedSrgs();
                 ir.RemoveUnusedSrgs();
@@ -660,7 +658,6 @@ int main(int argc, const char* argv[])
             if (dumpsym)
             if (dumpsym)
             {
             {
                 DumpSymbols(ir);
                 DumpSymbols(ir);
-                doEmission = false;
             }
             }
             else if (!visitName.empty())
             else if (!visitName.empty())
             {
             {
@@ -675,16 +672,10 @@ int main(int argc, const char* argv[])
                     visitOptions |= possibleOption.first ? possibleOption.second : RE::EnumType(0);
                     visitOptions |= possibleOption.first ? possibleOption.second : RE::EnumType(0);
                 }
                 }
                 PrintVisitSymbol(ir, visitName, visitOptions);
                 PrintVisitSymbol(ir, visitName, visitOptions);
-                doEmission = false;
             }
             }
-            else
-            {
-                bool checkerFlagsPresent = semantic || verbose; // or --syntax but we already exited by now.
-                doEmission = !checkerFlagsPresent;
-            }
-
-            if (doEmission)
+            else if (!semantic)  // do emission
             {
             {
+                verboseCout << "--Emission/Reflection--\n";
                 std::ofstream mainOutFile;
                 std::ofstream mainOutFile;
 
 
                 if (useOutputFile)
                 if (useOutputFile)

+ 47 - 22
src/AzslcReflection.cpp

@@ -1052,52 +1052,80 @@ namespace AZ::ShaderCompiler
             // only options
             // only options
             if (varInfo->CheckHasStorageFlag(StorageFlag::Option))
             if (varInfo->CheckHasStorageFlag(StorageFlag::Option))
             {
             {
+                verboseCout << "Analyzing " << uid << "\n";
                 int impactScore = 0;
                 int impactScore = 0;
                 // loop over appearances over the program
                 // loop over appearances over the program
                 for (Seenat& ref : kindInfo->GetSeenats())
                 for (Seenat& ref : kindInfo->GetSeenats())
                 {
                 {
+                    verboseCout << "Seen-at line " << ref.m_where.m_line << "\n";
                     // determine an impact score
                     // determine an impact score
                     impactScore += AnalyzeImpact(ref.m_where)  // dependent code that may be skipped depending on the value of that ref
                     impactScore += AnalyzeImpact(ref.m_where)  // dependent code that may be skipped depending on the value of that ref
                         + 1;  // by virtue of being mentioned (seenat), we count the reference as an access of cost 1.
                         + 1;  // by virtue of being mentioned (seenat), we count the reference as an access of cost 1.
                 }
                 }
                 varInfo->m_estimatedCostImpact = impactScore;
                 varInfo->m_estimatedCostImpact = impactScore;
+                verboseCout << uid << " final cost " << impactScore << "\n";
             }
             }
         }
         }
     }
     }
 
 
+    template< typename CtxT >
+    bool SetNextNodeIfChildOfCtxTCondViaNParents(
+        ParserRuleContext*& node,
+        int maxDepth)
+    {
+        if (auto* searchNode = DeepParentAs<CtxT*>(node, maxDepth))
+        {
+            if (IsParentOf(searchNode->Condition, node))
+            {
+                node = searchNode->embeddedStatement();
+                return true;
+            }
+        }
+        return false;
+    }
+
     int CodeReflection::AnalyzeImpact(TokensLocation const& location) const
     int CodeReflection::AnalyzeImpact(TokensLocation const& location) const
     {
     {
         // find the node at `location`:
         // find the node at `location`:
         ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId);
         ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId);
-        // go up tree to meet a block node that has visitable depth:
-        // can be any of if/for/while/switch
-        //  4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binaryop->braces->cmpexpr->cond->for
-        if (auto* whileNode = DeepParentAs<azslParser::WhileStatementContext*>(node->parent, 3))
-        {
-            node = whileNode->embeddedStatement();
-        }
-        else if (auto* ifNode = DeepParentAs<azslParser::IfStatementContext*>(node->parent, 3))
-        {
-            node = ifNode->embeddedStatement();
-        }
-        else if (auto* forNode = DeepParentAs<azslParser::ForStatementContext*>(node->parent, 4))
+        // the "belonging" statements that we will consider, before recursing:
+        using AstIf     = azslParser::IfStatementContext;
+        using AstFor    = azslParser::ForStatementContext;
+        using AstWhile  = azslParser::WhileStatementContext;
+        using AstDo     = azslParser::DoStatementContext;
+        using AstSwitch = azslParser::SwitchStatementContext;
+        // go up the tree to meet one of them using arbitrary max depths of {5,6,7},
+        // just enough to search up things like `for (a, b<(ref+1), c)` idExpr->IdentifierExpression->OtherExpression->BinaryExpr->ParenthesisExpr->BinaryExpr->Condition->For
+        int complexityFactor = 1;
+        bool isNonLoop = SetNextNodeIfChildOfCtxTCondViaNParents<AstIf>(node, 6);
+        if (!isNonLoop && (SetNextNodeIfChildOfCtxTCondViaNParents<AstFor>(node, 7)
+                           || SetNextNodeIfChildOfCtxTCondViaNParents<AstWhile>(node, 6)
+                           || SetNextNodeIfChildOfCtxTCondViaNParents<AstDo>(node, 6)))
         {
         {
-            node = forNode->embeddedStatement();
+            complexityFactor = 2; // arbitrarily augment loop scores by virtue of assuming they repeat O(N=2)
         }
         }
-        else if (auto* switchNode = DeepParentAs<azslParser::SwitchStatementContext*>(node->parent, 3))
+        else if (auto* switchNode = DeepParentAs<AstSwitch*>(node, 5))
         {
         {
-            node = switchNode->switchBlock();
+            if (IsParentOf(switchNode->Expr, node))
+            {
+                node = switchNode->switchBlock();
+            }
         }
         }
         int score = 0;
         int score = 0;
         AnalyzeImpact(node, score);
         AnalyzeImpact(node, score);
-        return score;
+        return score * complexityFactor;
     }
     }
 
 
     void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const
     void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const
     {
     {
         for (auto& c : astNode->children)
         for (auto& c : astNode->children)
         {
         {
-            if (auto* callNode = As<azslParser::FunctionCallExpressionContext*>(c))
+            if (auto* leaf = As<tree::TerminalNode*>(c))
+            {
+                // determine cost by number of full expressions separated by semicolon
+                scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi;  // bool as 0 or 1 trick
+            }
+            else if (auto* callNode = As<azslParser::FunctionCallExpressionContext*>(c))
             {
             {
                 // branch into an overload specialized for function lookup:
                 // branch into an overload specialized for function lookup:
                 AnalyzeImpact(callNode, scoreAccumulator);
                 AnalyzeImpact(callNode, scoreAccumulator);
@@ -1106,11 +1134,6 @@ namespace AZ::ShaderCompiler
             {
             {
                 AnalyzeImpact(node, scoreAccumulator); // recurse down to make sure to capture embedded calls, like e.g. "x ? f() : 0;"
                 AnalyzeImpact(node, scoreAccumulator); // recurse down to make sure to capture embedded calls, like e.g. "x ? f() : 0;"
             }
             }
-            if (auto* leaf = As<tree::TerminalNode*>(c))
-            {
-                // determine cost by number of full expressions separated by semicolon
-                scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi;  // bool as 0 or 1 trick
-            }
         }
         }
     }
     }
 
 
@@ -1169,12 +1192,14 @@ namespace AZ::ShaderCompiler
                 {
                 {
                     if (funcInfo->m_costScore == -1)  // cost not yet discovered for this function
                     if (funcInfo->m_costScore == -1)  // cost not yet discovered for this function
                     {
                     {
+                        verboseCout << " " << concrete << " non-memoized. discovering cost\n";
                         funcInfo->m_costScore = 0;
                         funcInfo->m_costScore = 0;
                         using AstFDef = azslParser::HlslFunctionDefinitionContext;
                         using AstFDef = azslParser::HlslFunctionDefinitionContext;
                         AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
                         AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
                                       funcInfo->m_costScore);  // recurse and cache
                                       funcInfo->m_costScore);  // recurse and cache
                     }
                     }
                     scoreAccumulator += funcInfo->m_costScore;
                     scoreAccumulator += funcInfo->m_costScore;
+                    verboseCout << " " << concrete << " cost score " << funcInfo->m_costScore << " added\n";
                 }
                 }
             }
             }
             // other cases forfeited for now, but that would at least include things like eg braces (f)()
             // other cases forfeited for now, but that would at least include things like eg braces (f)()

+ 1 - 1
src/AzslcSemanticOrchestrator.cpp

@@ -1123,7 +1123,7 @@ namespace AZ::ShaderCompiler
             {
             {
                 auto outputStream = lhsKindInfo.GetKind() == Kind::Type ? verboseCout : warningCout;
                 auto outputStream = lhsKindInfo.GetKind() == Kind::Type ? verboseCout : warningCout;
                 // registered predefined types can have members, but we don't know them -> not important. But anything else is very likely an ill-formed source.
                 // registered predefined types can have members, but we don't know them -> not important. But anything else is very likely an ill-formed source.
-                PrintWarning(outputStream, Warn::W1, line, none, " warning: ill-formed semantics: access of member ",
+                PrintWarning(outputStream, Warn::W1, line, none, "ill-formed semantics: access of member ",
                              " on an unsupported kind ", Kind::ToStr(lhsKindInfo.GetKind()),
                              " on an unsupported kind ", Kind::ToStr(lhsKindInfo.GetKind()),
                              " (of believed type ", lhsSymbol->first.GetName(),
                              " (of believed type ", lhsSymbol->first.GetName(),
                              (lhsExpressionText ? " from expression " + *lhsExpressionText : ""), ")");
                              (lhsExpressionText ? " from expression " + *lhsExpressionText : ""), ")");

+ 15 - 0
src/AzslcUtils.h

@@ -958,6 +958,21 @@ namespace AZ::ShaderCompiler
         return UnqualifiedName{ctx->Name->getText()};
         return UnqualifiedName{ctx->Name->getText()};
     }
     }
 
 
+    //! Verify filiation of an AST rule
+    inline bool IsParentOf(tree::ParseTree* assumedParent, tree::ParseTree* assumedChild)
+    {
+        tree::ParseTree* parent =  assumedChild->parent;
+        while (parent)
+        {
+            if (parent == assumedParent)
+            {
+                return true;
+            }
+            parent = parent->parent;
+        }
+        return false;
+    }
+
     //! Get a pointer to the first parent that happens to be of type `SearchType`
     //! Get a pointer to the first parent that happens to be of type `SearchType`
     //! with a limit depth of `maxDepth` parents to search through
     //! with a limit depth of `maxDepth` parents to search through
     template <typename SearchType>
     template <typename SearchType>

+ 1 - 1
src/DiagnosticStream.h

@@ -162,6 +162,6 @@ namespace AZ
     private:
     private:
         Warn m_warningAsErrorLevel = Warn::EndEnumeratorSentinel_;  //!< no warning is an error by default
         Warn m_warningAsErrorLevel = Warn::EndEnumeratorSentinel_;  //!< no warning is an error by default
         Warn m_warningLevel = Warn::W1;                             //!< current activated level setting. default warning is W1
         Warn m_warningLevel = Warn::W1;                             //!< current activated level setting. default warning is W1
-        stack<Warn> m_activeManipulator{ {Warn::W1} };          //!< store manipulators. start with an initial value corresponding to the default filter.
+        stack<Warn> m_activeManipulator{ {Warn::W1} };              //!< store manipulators. start with an initial value corresponding to the default filter.
     };
     };
 }
 }

+ 1 - 1
src/azslParser.g4

@@ -232,7 +232,7 @@ embeddedStatement:
     |   attributeSpecifier* Switch LeftParen Expr=expressionExt RightParen switchBlock # SwitchStatement
     |   attributeSpecifier* Switch LeftParen Expr=expressionExt RightParen switchBlock # SwitchStatement
 
 
     // Iteration statement
     // Iteration statement
-    |   attributeSpecifier* While LeftParen condition=expressionExt RightParen embeddedStatement # WhileStatement
+    |   attributeSpecifier* While LeftParen Condition=expressionExt RightParen embeddedStatement # WhileStatement
     |   attributeSpecifier* Do embeddedStatement While LeftParen Condition=expressionExt RightParen Semi # DoStatement
     |   attributeSpecifier* Do embeddedStatement While LeftParen Condition=expressionExt RightParen Semi # DoStatement
     |   attributeSpecifier* For LeftParen forInitializer? Semi Condition=expressionExt? Semi iterator=expressionExt? RightParen embeddedStatement # ForStatement
     |   attributeSpecifier* For LeftParen forInitializer? Semi Condition=expressionExt? Semi iterator=expressionExt? RightParen embeddedStatement # ForStatement
 
 

+ 1 - 1
src/generated/azslParser.cpp

@@ -5680,7 +5680,7 @@ azslParser::EmbeddedStatementContext* azslParser::embeddedStatement() {
       setState(606);
       setState(606);
       match(azslParser::LeftParen);
       match(azslParser::LeftParen);
       setState(607);
       setState(607);
-      antlrcpp::downCast<WhileStatementContext *>(_localctx)->condition = expressionExt(0);
+      antlrcpp::downCast<WhileStatementContext *>(_localctx)->Condition = expressionExt(0);
       setState(608);
       setState(608);
       match(azslParser::RightParen);
       match(azslParser::RightParen);
       setState(609);
       setState(609);

+ 1 - 1
src/generated/azslParser.h

@@ -1097,7 +1097,7 @@ public:
   public:
   public:
     WhileStatementContext(EmbeddedStatementContext *ctx);
     WhileStatementContext(EmbeddedStatementContext *ctx);
 
 
-    azslParser::ExpressionExtContext *condition = nullptr;
+    azslParser::ExpressionExtContext *Condition = nullptr;
     antlr4::tree::TerminalNode *While();
     antlr4::tree::TerminalNode *While();
     antlr4::tree::TerminalNode *LeftParen();
     antlr4::tree::TerminalNode *LeftParen();
     antlr4::tree::TerminalNode *RightParen();
     antlr4::tree::TerminalNode *RightParen();