瀏覽代碼

First working order of code that can track method calls cost.
Now able to locate the symbols because the starting scope is correctly reconstructed, using the new IntervalCollection class which is able to support query for non-disjointed intervals, which is a more difficult case than what we had up to now. We still keep the previous map to functions because it's faster to query.

Signed-off-by: Vivien Oddou <[email protected]>

Vivien Oddou 2 年之前
父節點
當前提交
7042ebf23a
共有 3 個文件被更改,包括 44 次插入30 次删除
  1. 12 11
      src/AzslcReflection.cpp
  2. 5 2
      src/AzslcReflection.h
  3. 27 17
      src/GenericUtils.h

+ 12 - 11
src/AzslcReflection.cpp

@@ -914,7 +914,7 @@ namespace AZ::ShaderCompiler
 
         // Prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
         // (truth: we need a set of disjoint intervals as an invariant for the following algorithm)
-        GenerateScopeStartToFunctionIntervalsReverseMap();
+        GenerateTokenScopeIntervalToUidReverseMap();
 
         Json::Value srgRoot(Json::objectValue);
         // Order the reflection by SRG for convenience
@@ -1044,8 +1044,8 @@ namespace AZ::ShaderCompiler
 
     void CodeReflection::AnalyzeOptionRanks() const
     {
-        // make sure we have the function scope lookup cache ready
-        GenerateScopeStartToFunctionIntervalsReverseMap();
+        // make sure we have the scope lookup cache ready
+        GenerateTokenScopeIntervalToUidReverseMap();
         // loop over variables
         for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3<VarInfo>())
         {
@@ -1121,14 +1121,10 @@ namespace AZ::ShaderCompiler
         // figure out the scope at this token.
         // theoretically should be something in the like of the body of another function,
         // or an anonymous block within another function.
-        auto intervalIter = FindIntervalInDisjointSet(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(),
-                                                      [](ssize_t key, auto& value)
-                                                      {
-                                                          return value.first.properlyContains({key, key});
-                                                      });
-        if (intervalIter != m_functionIntervals.cend())
+        auto interval = m_intervals.GetClosestIntervalSurrounding(callNode->start->getTokenIndex());
+        if (!interval.IsEmpty())
         {
-            IdentifierUID encloser = intervalIter->second.second;
+            IdentifierUID encloser = m_intervalToUid[interval];
 
             // Because we are past the end of the semantic analysis,
             // the scope tracker is registering the last seen scope (surely "/").
@@ -1147,6 +1143,7 @@ namespace AZ::ShaderCompiler
             else if (auto* maeExpr = As<AstMemberAccess*>(callNode->Expr))
             {
                 startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr);
+                funcName = ExtractNameFromIdExpression(maeExpr->Member);
             }
             IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName);
             if (!overload) // in case of function not found, we assume it's an intrinsic.
@@ -1193,7 +1190,7 @@ namespace AZ::ShaderCompiler
         }
     }
 
-    void CodeReflection::GenerateScopeStartToFunctionIntervalsReverseMap() const
+    void CodeReflection::GenerateTokenScopeIntervalToUidReverseMap() const
     {
         if (m_functionIntervals.empty())
         {
@@ -1204,7 +1201,11 @@ namespace AZ::ShaderCompiler
                     // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
                     m_functionIntervals[interval.a] = std::make_pair(interval, uid);
                 }
+                auto i = Interval<ssize_t>{interval.a, interval.b};
+                m_intervals.Add(i);
+                m_intervalToUid[i] = uid;
             }
+            m_intervals.Seal();
         }
     }
 }

+ 5 - 2
src/AzslcReflection.h

@@ -12,6 +12,7 @@
 namespace AZ::ShaderCompiler
 {
     using MapOfBeginToSpanAndUid = map<ssize_t, pair< misc::Interval, IdentifierUID> >;
+    using MapOfIntervalToUid = map<Interval<ssize_t>, IdentifierUID>;
 
     struct CodeReflection : Backend
     {
@@ -89,8 +90,10 @@ namespace AZ::ShaderCompiler
         void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const;
 
         //! Useful for static analysis on dependencies or option ranks
-        void GenerateScopeStartToFunctionIntervalsReverseMap() const;
-        mutable MapOfBeginToSpanAndUid m_functionIntervals; //< cache for the result of above function call
+        void GenerateTokenScopeIntervalToUidReverseMap() const;
+        mutable MapOfBeginToSpanAndUid m_functionIntervals;  //< only functions because they are guaranteed to be disjointed (largely simplifies queries)
+        mutable IntervalCollection<ssize_t> m_intervals;  //< augmented version with anonymous blocks (slower query)
+        mutable MapOfIntervalToUid m_intervalToUid;  //< side by side data since we don't want to weight the interval struct with a payload
 
         std::ostream& m_out;
     };

+ 27 - 17
src/GenericUtils.h

@@ -600,11 +600,15 @@ namespace AZ
     {
         using Interval = Interval<T>;
 
-        //! Construction from an iterable collection of Interval typed elements
-        template< typename Iterator >
-        IntervalCollection(Iterator&& begin, Iterator&& end)
-            : m_obfirsts(begin, end), m_oblasts(begin, end)
+        void Add(Interval i)
         {
+            m_obfirsts.emplace_back(i);
+        }
+
+        //! doesn't respect RAII but for the sake of performance and convenience this is easier this way
+        void Seal()
+        {
+            m_oblasts = m_obfirsts;
             std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2)
                       {
                           return i1.a < i2.a;
@@ -613,27 +617,30 @@ namespace AZ
                       {
                           return i1.b < i2.b;
                       });
+            m_sealed = true;
         }
 
         //! Retrieve the subset of intervals activated by a point (query)
-        set<Interval> GetIntervalsSurrounding(T query)
+        set<Interval> GetIntervalsSurrounding(T query) const
         {
-            // construct the set of intervals starting before:
-            set<Interval> startBefore;
-            CopyIf(m_obfirsts.begin(), m_obfirsts.end(),
-                   [=](auto interv) { return interv.a <= query; },
-                   std::inserter(startBefore, startBefore.end()),
-                   CopyIfPolicy::InterruptAtFirstFalse);  // because the obfirsts vector is sorted
+            assert(m_sealed);
+            // find the "set" of intervals starting before:
+            auto firstsSubEnd = std::lower_bound(m_obfirsts.begin(), m_obfirsts.end(),
+                                                 query,
+                                                 [=](auto interv, T q) { return interv.a <= q; });
 
-            // construct the set of intervals ending after:
-            set<Interval> endAfter;
+            // find the "set" of intervals ending after:
+            static vector<Interval> endAfter;
+            endAfter.clear();
             CopyIf(m_oblasts.rbegin(), m_oblasts.rend(),  // reverse iteration
                    [=](auto interv) { return interv.b >= query; },
-                   std::inserter(endAfter, endAfter.end()),
+                   std::back_inserter(endAfter),
                    CopyIfPolicy::InterruptAtFirstFalse);
+            // for set_intersection to work, the less<> predicate has to work for both ranges
+            std::sort(endAfter.begin(), endAfter.end());
 
             set<Interval> result;
-            std::set_intersection(startBefore.begin(), startBefore.end(),
+            std::set_intersection(m_obfirsts.begin(), firstsSubEnd,
                                   endAfter.begin(), endAfter.end(),
                                   std::inserter(result, result.end()));
             return result;
@@ -644,7 +651,7 @@ namespace AZ
         //! each overlapping interval is fully contained in the bigger one,
         //! the closest start is guaranteed to be the most "leaf" interval.
         //! This is useful for scopes.
-        Interval GetClosestIntervalSurrounding(T query)
+        Interval GetClosestIntervalSurrounding(T query) const
         {
             auto bag = GetIntervalsSurrounding(query);
             return bag.empty() ? Interval{-1, -2} : *bag.rbegin();
@@ -652,6 +659,7 @@ namespace AZ
 
         vector<Interval> m_obfirsts;  // ordered by "firsts"
         vector<Interval> m_oblasts;   // ordered by "lasts"
+        bool m_sealed = false;
     };
 
     template< typename Deduced >
@@ -829,7 +837,9 @@ namespace AZ::Tests
         assert(!IsIn("hibou", std::initializer_list<const char*>{ "chouette", "jay" }));
 
         Interval<int> intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}};
-        IntervalCollection<int> ic{std::begin(intvs), std::end(intvs)};
+        IntervalCollection<int> ic;
+        std::for_each(std::begin(intvs), std::end(intvs), [&](auto i) {ic.Add(i); });
+        ic.Seal();
         assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty());
         assert((ic.GetClosestIntervalSurrounding(0) == Interval<int>{0,10}));
         assert((ic.GetClosestIntervalSurrounding(1) == Interval<int>{1,5}));