Răsfoiți Sursa

First working order of code that can track method calls cost.
Now able to locate the symbols because the starting scope is correctly reconstructed, using the new IntervalCollection class which is able to support query for non-disjointed intervals, which is a more difficult case than what we had up to now. We still keep the previous map to functions because it's faster to query.

Signed-off-by: Vivien Oddou <[email protected]>

Vivien Oddou 2 ani în urmă
părinte
comite
7042ebf23a
3 a modificat fișierele cu 44 adăugiri și 30 ștergeri
  1. 12 11
      src/AzslcReflection.cpp
  2. 5 2
      src/AzslcReflection.h
  3. 27 17
      src/GenericUtils.h

+ 12 - 11
src/AzslcReflection.cpp

@@ -914,7 +914,7 @@ namespace AZ::ShaderCompiler
 
 
         // Prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
         // Prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
         // (truth: we need a set of disjoint intervals as an invariant for the following algorithm)
         // (truth: we need a set of disjoint intervals as an invariant for the following algorithm)
-        GenerateScopeStartToFunctionIntervalsReverseMap();
+        GenerateTokenScopeIntervalToUidReverseMap();
 
 
         Json::Value srgRoot(Json::objectValue);
         Json::Value srgRoot(Json::objectValue);
         // Order the reflection by SRG for convenience
         // Order the reflection by SRG for convenience
@@ -1044,8 +1044,8 @@ namespace AZ::ShaderCompiler
 
 
     void CodeReflection::AnalyzeOptionRanks() const
     void CodeReflection::AnalyzeOptionRanks() const
     {
     {
-        // make sure we have the function scope lookup cache ready
-        GenerateScopeStartToFunctionIntervalsReverseMap();
+        // make sure we have the scope lookup cache ready
+        GenerateTokenScopeIntervalToUidReverseMap();
         // loop over variables
         // loop over variables
         for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3<VarInfo>())
         for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3<VarInfo>())
         {
         {
@@ -1121,14 +1121,10 @@ namespace AZ::ShaderCompiler
         // figure out the scope at this token.
         // figure out the scope at this token.
         // theoretically should be something in the like of the body of another function,
         // theoretically should be something in the like of the body of another function,
         // or an anonymous block within another function.
         // or an anonymous block within another function.
-        auto intervalIter = FindIntervalInDisjointSet(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(),
-                                                      [](ssize_t key, auto& value)
-                                                      {
-                                                          return value.first.properlyContains({key, key});
-                                                      });
-        if (intervalIter != m_functionIntervals.cend())
+        auto interval = m_intervals.GetClosestIntervalSurrounding(callNode->start->getTokenIndex());
+        if (!interval.IsEmpty())
         {
         {
-            IdentifierUID encloser = intervalIter->second.second;
+            IdentifierUID encloser = m_intervalToUid[interval];
 
 
             // Because we are past the end of the semantic analysis,
             // Because we are past the end of the semantic analysis,
             // the scope tracker is registering the last seen scope (surely "/").
             // the scope tracker is registering the last seen scope (surely "/").
@@ -1147,6 +1143,7 @@ namespace AZ::ShaderCompiler
             else if (auto* maeExpr = As<AstMemberAccess*>(callNode->Expr))
             else if (auto* maeExpr = As<AstMemberAccess*>(callNode->Expr))
             {
             {
                 startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr);
                 startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr);
+                funcName = ExtractNameFromIdExpression(maeExpr->Member);
             }
             }
             IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName);
             IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName);
             if (!overload) // in case of function not found, we assume it's an intrinsic.
             if (!overload) // in case of function not found, we assume it's an intrinsic.
@@ -1193,7 +1190,7 @@ namespace AZ::ShaderCompiler
         }
         }
     }
     }
 
 
-    void CodeReflection::GenerateScopeStartToFunctionIntervalsReverseMap() const
+    void CodeReflection::GenerateTokenScopeIntervalToUidReverseMap() const
     {
     {
         if (m_functionIntervals.empty())
         if (m_functionIntervals.empty())
         {
         {
@@ -1204,7 +1201,11 @@ namespace AZ::ShaderCompiler
                     // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
                     // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
                     m_functionIntervals[interval.a] = std::make_pair(interval, uid);
                     m_functionIntervals[interval.a] = std::make_pair(interval, uid);
                 }
                 }
+                auto i = Interval<ssize_t>{interval.a, interval.b};
+                m_intervals.Add(i);
+                m_intervalToUid[i] = uid;
             }
             }
+            m_intervals.Seal();
         }
         }
     }
     }
 }
 }

+ 5 - 2
src/AzslcReflection.h

@@ -12,6 +12,7 @@
 namespace AZ::ShaderCompiler
 namespace AZ::ShaderCompiler
 {
 {
     using MapOfBeginToSpanAndUid = map<ssize_t, pair< misc::Interval, IdentifierUID> >;
     using MapOfBeginToSpanAndUid = map<ssize_t, pair< misc::Interval, IdentifierUID> >;
+    using MapOfIntervalToUid = map<Interval<ssize_t>, IdentifierUID>;
 
 
     struct CodeReflection : Backend
     struct CodeReflection : Backend
     {
     {
@@ -89,8 +90,10 @@ namespace AZ::ShaderCompiler
         void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const;
         void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const;
 
 
         //! Useful for static analysis on dependencies or option ranks
         //! Useful for static analysis on dependencies or option ranks
-        void GenerateScopeStartToFunctionIntervalsReverseMap() const;
-        mutable MapOfBeginToSpanAndUid m_functionIntervals; //< cache for the result of above function call
+        void GenerateTokenScopeIntervalToUidReverseMap() const;
+        mutable MapOfBeginToSpanAndUid m_functionIntervals;  //< only functions because they are guaranteed to be disjointed (largely simplifies queries)
+        mutable IntervalCollection<ssize_t> m_intervals;  //< augmented version with anonymous blocks (slower query)
+        mutable MapOfIntervalToUid m_intervalToUid;  //< side by side data since we don't want to weight the interval struct with a payload
 
 
         std::ostream& m_out;
         std::ostream& m_out;
     };
     };

+ 27 - 17
src/GenericUtils.h

@@ -600,11 +600,15 @@ namespace AZ
     {
     {
         using Interval = Interval<T>;
         using Interval = Interval<T>;
 
 
-        //! Construction from an iterable collection of Interval typed elements
-        template< typename Iterator >
-        IntervalCollection(Iterator&& begin, Iterator&& end)
-            : m_obfirsts(begin, end), m_oblasts(begin, end)
+        void Add(Interval i)
         {
         {
+            m_obfirsts.emplace_back(i);
+        }
+
+        //! doesn't respect RAII but for the sake of performance and convenience this is easier this way
+        void Seal()
+        {
+            m_oblasts = m_obfirsts;
             std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2)
             std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2)
                       {
                       {
                           return i1.a < i2.a;
                           return i1.a < i2.a;
@@ -613,27 +617,30 @@ namespace AZ
                       {
                       {
                           return i1.b < i2.b;
                           return i1.b < i2.b;
                       });
                       });
+            m_sealed = true;
         }
         }
 
 
         //! Retrieve the subset of intervals activated by a point (query)
         //! Retrieve the subset of intervals activated by a point (query)
-        set<Interval> GetIntervalsSurrounding(T query)
+        set<Interval> GetIntervalsSurrounding(T query) const
         {
         {
-            // construct the set of intervals starting before:
-            set<Interval> startBefore;
-            CopyIf(m_obfirsts.begin(), m_obfirsts.end(),
-                   [=](auto interv) { return interv.a <= query; },
-                   std::inserter(startBefore, startBefore.end()),
-                   CopyIfPolicy::InterruptAtFirstFalse);  // because the obfirsts vector is sorted
+            assert(m_sealed);
+            // find the "set" of intervals starting before:
+            auto firstsSubEnd = std::lower_bound(m_obfirsts.begin(), m_obfirsts.end(),
+                                                 query,
+                                                 [=](auto interv, T q) { return interv.a <= q; });
 
 
-            // construct the set of intervals ending after:
-            set<Interval> endAfter;
+            // find the "set" of intervals ending after:
+            static vector<Interval> endAfter;
+            endAfter.clear();
             CopyIf(m_oblasts.rbegin(), m_oblasts.rend(),  // reverse iteration
             CopyIf(m_oblasts.rbegin(), m_oblasts.rend(),  // reverse iteration
                    [=](auto interv) { return interv.b >= query; },
                    [=](auto interv) { return interv.b >= query; },
-                   std::inserter(endAfter, endAfter.end()),
+                   std::back_inserter(endAfter),
                    CopyIfPolicy::InterruptAtFirstFalse);
                    CopyIfPolicy::InterruptAtFirstFalse);
+            // for set_intersection to work, the less<> predicate has to work for both ranges
+            std::sort(endAfter.begin(), endAfter.end());
 
 
             set<Interval> result;
             set<Interval> result;
-            std::set_intersection(startBefore.begin(), startBefore.end(),
+            std::set_intersection(m_obfirsts.begin(), firstsSubEnd,
                                   endAfter.begin(), endAfter.end(),
                                   endAfter.begin(), endAfter.end(),
                                   std::inserter(result, result.end()));
                                   std::inserter(result, result.end()));
             return result;
             return result;
@@ -644,7 +651,7 @@ namespace AZ
         //! each overlapping interval is fully contained in the bigger one,
         //! each overlapping interval is fully contained in the bigger one,
         //! the closest start is guaranteed to be the most "leaf" interval.
         //! the closest start is guaranteed to be the most "leaf" interval.
         //! This is useful for scopes.
         //! This is useful for scopes.
-        Interval GetClosestIntervalSurrounding(T query)
+        Interval GetClosestIntervalSurrounding(T query) const
         {
         {
             auto bag = GetIntervalsSurrounding(query);
             auto bag = GetIntervalsSurrounding(query);
             return bag.empty() ? Interval{-1, -2} : *bag.rbegin();
             return bag.empty() ? Interval{-1, -2} : *bag.rbegin();
@@ -652,6 +659,7 @@ namespace AZ
 
 
         vector<Interval> m_obfirsts;  // ordered by "firsts"
         vector<Interval> m_obfirsts;  // ordered by "firsts"
         vector<Interval> m_oblasts;   // ordered by "lasts"
         vector<Interval> m_oblasts;   // ordered by "lasts"
+        bool m_sealed = false;
     };
     };
 
 
     template< typename Deduced >
     template< typename Deduced >
@@ -829,7 +837,9 @@ namespace AZ::Tests
         assert(!IsIn("hibou", std::initializer_list<const char*>{ "chouette", "jay" }));
         assert(!IsIn("hibou", std::initializer_list<const char*>{ "chouette", "jay" }));
 
 
         Interval<int> intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}};
         Interval<int> intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}};
-        IntervalCollection<int> ic{std::begin(intvs), std::end(intvs)};
+        IntervalCollection<int> ic;
+        std::for_each(std::begin(intvs), std::end(intvs), [&](auto i) {ic.Add(i); });
+        ic.Seal();
         assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty());
         assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty());
         assert((ic.GetClosestIntervalSurrounding(0) == Interval<int>{0,10}));
         assert((ic.GetClosestIntervalSurrounding(0) == Interval<int>{0,10}));
         assert((ic.GetClosestIntervalSurrounding(1) == Interval<int>{1,5}));
         assert((ic.GetClosestIntervalSurrounding(1) == Interval<int>{1,5}));