Sfoglia il codice sorgente

Improve option rank analysis code by support of Do-Statement and a 2weights if they appear in loops. Also fixup possibility of false-counting blocks that are related to a too distant parent. And fix the problem of very easy max-depth to overcome which made o_enableShadows cost only 7 when after going from max depth of 4 to 6 for Ast search, it now finds 2489 impactCost for that option.

Signed-off-by: Vivien Oddou <[email protected]>
Vivien Oddou 2 anni fa
parent
commit
b766ca2f73
4 ha cambiato i file con 62 aggiunte e 23 eliminazioni
  1. 1 0
      src/AzslcIntermediateRepresentation.cpp
  2. 47 22
      src/AzslcReflection.cpp
  3. 13 0
      src/AzslcUtils.h
  4. 1 1
      src/azslParser.g4

+ 1 - 0
src/AzslcIntermediateRepresentation.cpp

@@ -121,6 +121,7 @@ namespace AZ::ShaderCompiler
         // We're not going to transform our flat vector to tree and to vector again,
         // instead we'll shift around the elements to respect that order.
         // And to do so, we'll use a dependency DAG and a topological solver.
+        verboseCout << "--Middle End--\n";
 
         m_symbols.ReorderBySymbolDependency();
 

+ 47 - 22
src/AzslcReflection.cpp

@@ -1052,52 +1052,80 @@ namespace AZ::ShaderCompiler
             // only options
             if (varInfo->CheckHasStorageFlag(StorageFlag::Option))
             {
+                verboseCout << "Analyzing " << uid << "\n";
                 int impactScore = 0;
                 // loop over appearances over the program
                 for (Seenat& ref : kindInfo->GetSeenats())
                 {
+                    verboseCout << "Seenat line " << ref.m_where.m_line << "\n";
                     // determine an impact score
                     impactScore += AnalyzeImpact(ref.m_where)  // dependent code that may be skipped depending on the value of that ref
                         + 1;  // by virtue of being mentioned (seenat), we count the reference as an access of cost 1.
                 }
                 varInfo->m_estimatedCostImpact = impactScore;
+                verboseCout << uid << " final cost " << impactScore << "\n";
             }
         }
     }
 
+    template< typename CtxT >
+    bool SetNextNodeIfPartOfTypeCtxTCondViaNParents(
+        ParserRuleContext*& node,
+        int maxDepth)
+    {
+        if (auto* searchNode = DeepParentAs<CtxT*>(node, maxDepth))
+        {
+            if (IsParentOf(searchNode->Condition, node))
+            {
+                node = searchNode->embeddedStatement();
+                return true;
+            }
+        }
+        return false;
+    }
+
     int CodeReflection::AnalyzeImpact(TokensLocation const& location) const
     {
         // find the node at `location`:
         ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId);
-        // go up tree to meet a block node that has visitable depth:
-        // can be any of if/for/while/switch
-        //  4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binaryop->braces->cmpexpr->cond->for
-        if (auto* whileNode = DeepParentAs<azslParser::WhileStatementContext*>(node->parent, 3))
-        {
-            node = whileNode->embeddedStatement();
-        }
-        else if (auto* ifNode = DeepParentAs<azslParser::IfStatementContext*>(node->parent, 3))
-        {
-            node = ifNode->embeddedStatement();
-        }
-        else if (auto* forNode = DeepParentAs<azslParser::ForStatementContext*>(node->parent, 4))
+        // the "belonging" statements that we will consider, before recursing:
+        using AstIf     = azslParser::IfStatementContext;
+        using AstFor    = azslParser::ForStatementContext;
+        using AstWhile  = azslParser::WhileStatementContext;
+        using AstDo     = azslParser::DoStatementContext;
+        using AstSwitch = azslParser::SwitchStatementContext;
+        // go up the tree to meet one of them using arbitrary max depths of {5,6,7},
+        // just enough to search up things like `for (a, b<(ref+1), c)` idExpr->IdentifierExpression->OtherExpression->BinaryExpr->ParenthesisExpr->BinaryExpr->Condition->For
+        int complexityFactor = 1;
+        bool isNonLoop = SetNextNodeIfPartOfTypeCtxTCondViaNParents<AstIf>(node, 6);
+        if (!isNonLoop && (SetNextNodeIfPartOfTypeCtxTCondViaNParents<AstFor>(node, 7)
+                           || SetNextNodeIfPartOfTypeCtxTCondViaNParents<AstWhile>(node, 6)
+                           || SetNextNodeIfPartOfTypeCtxTCondViaNParents<AstDo>(node, 6)))
         {
-            node = forNode->embeddedStatement();
+            complexityFactor = 2; // arbitrarily augment loop scores by virtue of assuming they repeat O(N=2)
         }
-        else if (auto* switchNode = DeepParentAs<azslParser::SwitchStatementContext*>(node->parent, 3))
+        else if (auto* switchNode = DeepParentAs<AstSwitch*>(node, 5))
         {
-            node = switchNode->switchBlock();
+            if (IsParentOf(switchNode->Expr, node))
+            {
+                node = switchNode->switchBlock();
+            }
         }
         int score = 0;
         AnalyzeImpact(node, score);
-        return score;
+        return score * complexityFactor;
     }
 
     void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const
     {
         for (auto& c : astNode->children)
         {
-            if (auto* callNode = As<azslParser::FunctionCallExpressionContext*>(c))
+            if (auto* leaf = As<tree::TerminalNode*>(c))
+            {
+                // determine cost by number of full expressions separated by semicolon
+                scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi;  // bool as 0 or 1 trick
+            }
+            else if (auto* callNode = As<azslParser::FunctionCallExpressionContext*>(c))
             {
                 // branch into an overload specialized for function lookup:
                 AnalyzeImpact(callNode, scoreAccumulator);
@@ -1106,11 +1134,6 @@ namespace AZ::ShaderCompiler
             {
                 AnalyzeImpact(node, scoreAccumulator); // recurse down to make sure to capture embedded calls, like e.g. "x ? f() : 0;"
             }
-            if (auto* leaf = As<tree::TerminalNode*>(c))
-            {
-                // determine cost by number of full expressions separated by semicolon
-                scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi;  // bool as 0 or 1 trick
-            }
         }
     }
 
@@ -1173,8 +1196,10 @@ namespace AZ::ShaderCompiler
                         using AstFDef = azslParser::HlslFunctionDefinitionContext;
                         AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
                                       funcInfo->m_costScore);  // recurse and cache
+                        verboseCout << " " << concrete << " analyzed at " << funcInfo->m_costScore << "\n";
                     }
                     scoreAccumulator += funcInfo->m_costScore;
+                    verboseCout << " " << concrete << " call score " << funcInfo->m_costScore << " added. now " << scoreAccumulator << "\n";
                 }
             }
             // other cases forfeited for now, but that would at least include things like eg braces (f)()

+ 13 - 0
src/AzslcUtils.h

@@ -958,6 +958,19 @@ namespace AZ::ShaderCompiler
         return UnqualifiedName{ctx->Name->getText()};
     }
 
+    //! Verify filiation of an AST rule
+    inline bool IsParentOf(tree::ParseTree* assumedParent, tree::ParseTree* assumedChild)
+    {
+        tree::ParseTree* parent =  assumedChild->parent;
+        while (parent)
+        {
+            if (parent == assumedParent)
+                return true;
+            parent = parent->parent;
+        }
+        return false;
+    }
+
     //! Get a pointer to the first parent that happens to be of type `SearchType`
     //! with a limit depth of `maxDepth` parents to search through
     template <typename SearchType>

+ 1 - 1
src/azslParser.g4

@@ -232,7 +232,7 @@ embeddedStatement:
     |   attributeSpecifier* Switch LeftParen Expr=expressionExt RightParen switchBlock # SwitchStatement
 
     // Iteration statement
-    |   attributeSpecifier* While LeftParen condition=expressionExt RightParen embeddedStatement # WhileStatement
+    |   attributeSpecifier* While LeftParen Condition=expressionExt RightParen embeddedStatement # WhileStatement
     |   attributeSpecifier* Do embeddedStatement While LeftParen Condition=expressionExt RightParen Semi # DoStatement
     |   attributeSpecifier* For LeftParen forInitializer? Semi Condition=expressionExt? Semi iterator=expressionExt? RightParen embeddedStatement # ForStatement