2
0
Эх сурвалжийг харах

Refactor the option rank analysis to make it easier to tread with named function for each action.
+ prepare a fix for the <failed> type analysis of the "method call solver".
This fail is due to the fact that we are past the semantic analysis, so we don't have proper scope tracking.
The lookup cannot work if we don't provide a starting scope.
That starting scope is reconstructed artificially using the scopes/token map collection.
Unfortunately, the fast lookup is now using an intermediate map that filters by function only.
If we don't include the unnamed blocks, Lookup will systematically fail for any object within curly brace or an if block, for block etc.
To solve that problem, I prepared a "non disjointed" interval query system.
It's an unfortunate change from Log(n) by query to O(N) by query though. Also we have a memory fest since these are node containers. We might want to consider Howard Hinnant stack allocator soon after.

Signed-off-by: Vivien Oddou <[email protected]>

Vivien Oddou 2 жил өмнө
parent
commit
661d01ffb6

+ 128 - 89
src/AzslcReflection.cpp

@@ -861,10 +861,10 @@ namespace AZ::ShaderCompiler
             assert(uid == seenat.m_referredDefinition);
             // careful of the invariant: distinct intervals. (can't support functions nested in functions nor imbricated block scopes)
             // ok for now because AZSL/HLSL don't have lambdas
-            auto intervalIter = FindInterval(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
-                                             {
-                                                 return value.first.properlyContains({key, key});
-                                             });
+            auto intervalIter = FindIntervalInDisjointSet(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
+                                                          {
+                                                              return value.first.properlyContains({key, key});
+                                                          });
             if (intervalIter != scopes.cend())
             {
                 const IdentifierUID& encloser = intervalIter->second.second;
@@ -1001,6 +1001,47 @@ namespace AZ::ShaderCompiler
         m_out << srgRoot;
     }
 
+    // Helper routine for option rank analysis
+    static int GuesstimateIntrinsicFunctionCost(string_view funcName)
+    {
+        if (IsOneOf(funcName, "CallShader", "TraceRay"))
+        { // non measurable but assumed high
+            return 100;
+        }
+        else if (IsOneOf(funcName, "Sample", "Load", "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append"))
+        { // memory access, locked or not, will have high latency
+            return 10;
+        }
+        else
+        { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1.
+            return 1;
+        }
+    }
+
+    // Helper routine for option rank analysis. When picking AN overload is more useful than forfeiting.
+    // The function GetConcreteFunctionThatMatchesArgumentList forfeits when the overloadset contains
+    // strictly more than 1 concrete function with the queried arity. In our case, we prefer to just pick any.
+    static IdentifierUID PickAnyOverloadThatMatchesArgCount(IntermediateRepresentation* ir,
+                                                            azslParser::FunctionCallExpressionContext* callNode,
+                                                            KindInfo& overload)
+    {
+        IdentifierUID concrete;
+        size_t numArgs = NumArgs(callNode);
+        overload.GetSubAs<OverloadSetInfo>()->AnyOf(
+            [&](IdentifierUID const& uid)
+            {
+                auto* concreteFcInfo = ir->GetSymbolSubAs<FunctionInfo>(uid.GetName());
+                size_t numParams = concreteFcInfo->GetParameters(true).size();
+                if (numParams == numArgs)
+                {
+                    concrete = uid;  // we write the result through reference capture (not clean but convenient)
+                    return true;
+                }
+                return false;
+            });
+        return concrete;
+    }
+
     void CodeReflection::AnalyzeOptionRanks() const
     {
         // make sure we have the function scope lookup cache ready
@@ -1030,7 +1071,7 @@ namespace AZ::ShaderCompiler
         ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId);
         // go up tree to meet a block node that has visitable depth:
         // can be any of if/for/while/switch
-        //  4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binary op->braces->cmp expr->cond->for
+        //  4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binaryop->braces->cmpexpr->cond->for
         if (auto* whileNode = DeepParentAs<azslParser::WhileStatementContext*>(node->parent, 3))
         {
             node = whileNode->embeddedStatement();
@@ -1058,89 +1099,8 @@ namespace AZ::ShaderCompiler
         {
             if (auto* callNode = As<azslParser::FunctionCallExpressionContext*>(c))
             {
-                // get function score in FunctionInfo if cached, compute it and store if not.
-                // to access the function symbol info we need the current scope, the function call name and perform a lookup.
-                auto intervalIter = FindInterval(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(),
-                                                 [](ssize_t key, auto& value)
-                                                 {
-                                                     return value.first.properlyContains({key, key});
-                                                 });
-                if (intervalIter != m_functionIntervals.cend())
-                {
-                    const IdentifierUID& encloser = intervalIter->second.second;
-                    // lookup function at AST node `callNode` from scope `encloser`
-                    if (auto* idExpr = As<azslParser::IdentifierExpressionContext*>(callNode->Expr))
-                    {
-                        UnqualifiedName funcName = ExtractNameFromIdExpression(idExpr->idExpression());
-                        IdAndKind* overload = m_ir->m_symbols.LookupSymbol(encloser.GetName(), funcName);
-                        if (!overload) // in case of function not found, we assume it's an intrinsic.
-                        {
-                            if (IsOneOf(funcName, "CallShader", "TraceRay"))
-                            { // non measurable but assumed high
-                                scoreAccumulator += 50;
-                            }
-                            else if (IsOneOf(funcName, "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append"))
-                            { // hardware locked memory ops, high weight
-                                scoreAccumulator += 10;
-                            }
-                            else if (IsOneOf(funcName, "Sample", "Load"))
-                            { // memory access is weighted in between
-                                scoreAccumulator += 5;
-                            }
-                            else
-                            { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1.
-                                scoreAccumulator += 1;
-                            }
-                        }
-                        else
-                        {
-                            azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode);
-                            IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args);
-                            IdentifierUID concrete;
-                            if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet)
-                            { // in case of strict selection failure, run a fuzzy select
-                                size_t numArgs = NumArgs(callNode);
-                                overload->second.GetSubAs<OverloadSetInfo>()->AnyOf(
-                                    [&](IdentifierUID const& uid)
-                                    {
-                                        auto* concreteFcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(uid.GetName());
-                                        size_t numParams = concreteFcInfo->GetParameters(true).size();
-                                        if (numParams == numArgs)
-                                        {
-                                            concrete = uid;
-                                            return true;
-                                        }
-                                        return false;
-                                    }
-                                );
-                                // if still not enough to get a fix, it might be an ill-formed input. prefer to forfeit
-                            }
-                            else
-                            {
-                                concrete = symbolMeantUnderCallNode->first;
-                            }
-                            auto* funcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(concrete.GetName());
-                            if (funcInfo->m_costScore == -1)
-                            {
-                                funcInfo->m_costScore = 0;
-                                using AstFDef = azslParser::HlslFunctionDefinitionContext;
-                                AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
-                                              funcInfo->m_costScore);  // recurse and cache if not already done
-                            }
-                            scoreAccumulator += funcInfo->m_costScore;
-                        }
-                    }
-                    // other cases forfeited for now, but that would at least be braces (f)() or MAE x.m()
-                }
-                else // no interval found
-                {
-                    // function calls outside of function bodies can appear in an initializer:
-                    //    int g_a = MakeA();  // global init
-                    //    class C { int m_a = CompA();  // constructor init (invalid AZSL/HLSL)
-                    //    class D { void Method(int a_a = DefaultA());  // default parameter value
-                    // in any case, extracting the scope is impossible with this system.
-                    // we forfeit evaluation of a score
-                }
+                // branch into an overload specialized for function lookup:
+                AnalyzeImpact(callNode, scoreAccumulator);
             }
             else if (auto* node = As<ParserRuleContext*>(c))
             {
@@ -1149,8 +1109,87 @@ namespace AZ::ShaderCompiler
             if (auto* leaf = As<tree::TerminalNode*>(c))
             {
                 // determine cost by number of full expressions separated by semicolon
-                scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi;
+                scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi;  // bool as 0 or 1 trick
+            }
+        }
+    }
+
+    void CodeReflection::AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const
+    {
+        // to access the function symbol info we need the current scope, the function call name and perform a lookup.
+
+        // figure out the scope at this token.
+        // theoretically should be something in the like of the body of another function,
+        // or an anonymous block within another function.
+        auto intervalIter = FindIntervalInDisjointSet(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(),
+                                                      [](ssize_t key, auto& value)
+                                                      {
+                                                          return value.first.properlyContains({key, key});
+                                                      });
+        if (intervalIter != m_functionIntervals.cend())
+        {
+            IdentifierUID encloser = intervalIter->second.second;
+
+            // Because we are past the end of the semantic analysis,
+            // the scope tracker is registering the last seen scope (surely "/").
+            // This is a stateful side-effect system unfortunately, and since we'll call
+            // some feature of the semantic orchestrator (like TypeofExpr) we need to hack
+            // the scope tracker:
+            m_ir->m_sema.m_scope->m_currentScopePath = encloser.GetName();
+            m_ir->m_sema.m_scope->UpdateCurScopeUID();
+
+            QualifiedName startupLookupScope = encloser.GetName();
+            UnqualifiedName funcName;
+            if (auto* idExpr = As<azslParser::IdentifierExpressionContext*>(callNode->Expr))
+            {
+                funcName = ExtractNameFromIdExpression(idExpr->idExpression());
+            }
+            else if (auto* maeExpr = As<AstMemberAccess*>(callNode->Expr))
+            {
+                startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr);
+            }
+            IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName);
+            if (!overload) // in case of function not found, we assume it's an intrinsic.
+            {
+                scoreAccumulator += GuesstimateIntrinsicFunctionCost(funcName);
             }
+            else
+            {
+                azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode);
+                IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args);
+                IdentifierUID concrete;
+                if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet)
+                {   // in case of strict selection failure, run a fuzzy select
+                    concrete = PickAnyOverloadThatMatchesArgCount(m_ir, callNode, overload->second);
+                    // if still not enough to get a fix (concrete=={}), it might be an ill-formed input. prefer to forfeit
+                }
+                else
+                {
+                    concrete = symbolMeantUnderCallNode->first;
+                }
+
+                if (auto* funcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(concrete.GetName()))
+                {
+                    if (funcInfo->m_costScore == -1)  // cost not yet discovered for this function
+                    {
+                        funcInfo->m_costScore = 0;
+                        using AstFDef = azslParser::HlslFunctionDefinitionContext;
+                        AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
+                                      funcInfo->m_costScore);  // recurse and cache
+                    }
+                    scoreAccumulator += funcInfo->m_costScore;
+                }
+            }
+            // other cases forfeited for now, but that would at least include things like eg braces (f)()
+        }
+        else // no interval found
+        {
+            // function calls outside of function bodies can appear in an initializer:
+            //    int g_a = MakeA();  // global init
+            //    class C { int m_a = CompA();  // constructor init (invalid AZSL/HLSL)
+            //    class D { void Method(int a_a = DefaultA());  // default parameter value
+            // in any case, extracting the scope is impossible with this system.
+            // we forfeit evaluation of a score
         }
     }
 

+ 3 - 0
src/AzslcReflection.h

@@ -85,6 +85,9 @@ namespace AZ::ShaderCompiler
         // Recursive internal detail version
         void AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const;
 
+        // Function-call specific
+        void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const;
+
         //! Useful for static analysis on dependencies or option ranks
         void GenerateScopeStartToFunctionIntervalsReverseMap() const;
         mutable MapOfBeginToSpanAndUid m_functionIntervals; //< cache for the result of above function call

+ 155 - 47
src/GenericUtils.h

@@ -24,9 +24,9 @@ namespace AZ
     {
         using runtime_error::runtime_error;
     };
-    // Type-heterogeneity-preserving multi pointer object single visitor.
-    // Returns whatever the passed functor would.
-    // Throws if all passed objects are null.
+    //! Type-heterogeneity-preserving multi pointer object single visitor.
+    //! Returns whatever the passed functor would.
+    //! Throws if all passed objects are null.
     template <typename Lambda, typename T>
     invoke_result_t<Lambda, T*> VisitFirstNonNull(Lambda functor, T* object) noexcept(false)
     {
@@ -50,9 +50,9 @@ namespace AZ
         }
     }
 
-    // Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics.
-    // what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]`
-    // http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/
+    //! Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics.
+    //! what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]`
+    //! http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/
     inline constexpr string_view Slice(const string_view& in, int64_t st, int64_t end)
     {
         auto inSSize = (int64_t)in.size();
@@ -107,8 +107,8 @@ namespace AZ
     //https://developercommunity.visualstudio.com/content/problem/275141/c2131-expression-did-not-evaluate-to-a-constant-fo.html
     }
 
-    // ability to create size_t literals
-    // waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo
+    //! ability to create size_t literals
+    //! waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo
     inline constexpr std::size_t operator ""_sz(unsigned long long n)
     {
         return n;
@@ -145,7 +145,7 @@ namespace AZ
         return fileName.substr(0, lastIndex);
     }
 
-    // debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library)
+    //! debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library)
     template <typename To, typename From>
     To polymorphic_downcast(From instance)
     {
@@ -157,7 +157,7 @@ namespace AZ
         return static_cast<To>(instance);
     }
 
-    /// surround a string with a prefix and a suffix
+    //! surround a string with a prefix and a suffix
     inline string Decorate(string_view prefix, string_view body, string_view suffix)
     {
         std::stringstream ss;
@@ -167,13 +167,13 @@ namespace AZ
         return ss.str();
     }
 
-    /// 2 arguments version in case both sides are the same
+    //! 2 arguments version in case both sides are the same
     inline string Decorate(string_view prefixAndSuffix, string_view body)
     {
         return Decorate(prefixAndSuffix, body, prefixAndSuffix);
     }
 
-    /// reverse the effect of a symmetrical decoration
+    //! reverse the effect of a symmetrical decoration
     inline string_view Undecorate(string_view decoration, string_view body)
     {
         auto indexStart = StartsWith(body, decoration) ? decoration.length() : 0;
@@ -181,7 +181,7 @@ namespace AZ
         return Slice(body, indexStart, indexEnd);
     }
 
-    // Erase-Remove algorithm which removes all whitespaces from a string.
+    //! Erase-Remove algorithm which removes all whitespaces from a string.
     inline string RemoveWhitespaces(string haystack)
     {
         haystack.erase(std::remove_if(haystack.begin(), haystack.end(), [](unsigned char c) {return std::isspace(c); }), haystack.end());
@@ -193,14 +193,14 @@ namespace AZ
         return std::all_of(s.begin(), s.end(), [&](char c) { return std::isspace(c); });
     }
 
-    /// tells whether a position in a string is surrounded by round braces
-    /// e.g. true  for arguments {"a(b)", 2}
-    /// e.g. true  for arguments {"a()", 1}  by convention
-    /// e.g. false for arguments {"a()", 2}  by convention
-    /// e.g. false for arguments {"a(b)", 0}
-    /// e.g. false for arguments {"a(b)c", 4}
-    /// e.g. false for arguments {"a(b)c(d)", 4}
-    /// e.g. true  for arguments {"a((b)c(d))", 5}
+    //! tells whether a position in a string is surrounded by round braces
+    //! e.g. true  for arguments {"a(b)", 2}
+    //! e.g. true  for arguments {"a()", 1}  by convention
+    //! e.g. false for arguments {"a()", 2}  by convention
+    //! e.g. false for arguments {"a(b)", 0}
+    //! e.g. false for arguments {"a(b)c", 4}
+    //! e.g. false for arguments {"a(b)c(d)", 4}
+    //! e.g. true  for arguments {"a((b)c(d))", 5}
     inline bool WithinMatchedParentheses(string_view haystack, size_t charPosition)
     {
         const auto hayLen = haystack.length();
@@ -215,8 +215,8 @@ namespace AZ
         return nesting > 0;
     }
 
-    /// replace all occurrences of substring `sub` with substring `to` within haystack.
-    ///  e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa"
+    //! replace all occurrences of substring `sub` with substring `to` within haystack.
+    //!  e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa"
     inline string Replace(string haystack, string_view sub, string_view to)
     {
         decltype(sub.length()) pos = 0;
@@ -230,7 +230,7 @@ namespace AZ
         return haystack;
     }
 
-    // this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes.
+    //! this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes.
     constexpr inline string_view Trim(string_view haystack, string_view toTrim = " \t\n")
     {
         const auto strEnd = haystack.find_last_not_of(toTrim);
@@ -268,34 +268,56 @@ namespace AZ
         return std::find_if(begin, end, p) != end;
     }
 
-    /// argument in rangeV3-style version:
+    //! argument in rangeV3-style version:
     template< typename Container >
     string Join(const Container& c, string_view separator = "")
     {
         return Join(c.begin(), c.end(), separator);
     }
 
-    /// argument in rangeV3-style version:
+    //! argument in rangeV3-style version:
     template< typename Container, typename Predicate >
     bool Contains(const Container& c, Predicate p)
     {
         return Contains(c.begin(), c.end(), p);
     }
 
-    /// closest possible form of python's `in` keyword
+    //! closest possible form of python's `in` keyword
     template< typename Element, typename Container >
     bool IsIn(const Element& element, const Container& container)
     {
         return std::find(container.begin(), container.end(), element) != container.end();
     }
 
-    /// generate a new container with copy-and-mutated elements
+    //! generate a new container with copy-and-mutated elements
     template< typename Container, typename ContainerOut, typename Functor >
     void TransformCopy(const Container& in, ContainerOut& out, Functor mutator)
     {
         std::transform(in.begin(), in.end(), std::back_inserter(out), mutator);
     }
 
+    enum class CopyIfPolicy { ForAll, InterruptAtFirstFalse };
+
+    //! inserts elements into the output iterator if they pass a predicate
+    template< typename InputIterator, typename Predicate, typename OutputIterator >
+    void CopyIf(InputIterator begin, InputIterator end,
+                Predicate pred,
+                OutputIterator out,
+                CopyIfPolicy policy)
+    {
+        for (auto it = begin; it != end; ++it)
+        {
+            if (pred(*it))
+            {
+                *out = *it;
+            }
+            else if (policy == CopyIfPolicy::InterruptAtFirstFalse)
+            {
+                break;
+            }
+        }
+    }
+
     inline string Unescape(string_view escapedText)
     {
         std::stringstream out;
@@ -531,7 +553,7 @@ namespace AZ
 #endif
     }
 
-    /// Log(N) find immediate lower element query in map-keys
+    //! Log(N) query to find the first immediately lower or equal element in a map's keys
     template< typename T, typename U >
     auto Infimum(map<T, U> const& ctr, T query)
     {
@@ -539,35 +561,106 @@ namespace AZ
         return it == ctr.begin() ? ctr.cend() : --it;
     }
 
-    /// Log(N) disjointed segments belong query
-    /// You can represent segments as you wish, as long as:
-    ///   you provide the predicate to determine belonging.
-    ///   map-keys are segment start points.
-    ///   segments don't overlap.
-    /// returns: iterator to found interval key, or cend()
+    //! Log(N) disjointed segments belong query
+    //! You can represent segments as you wish, as long as:
+    //!   you provide the predicate to determine belonging.
+    //!   map-keys are segment start points.
+    //!   segments don't overlap.
+    //! returns: iterator to found interval key, or cend()
     template< typename T, typename U, typename IntervalCheckPredicate >
-    auto FindInterval(const map<T, U>& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate)
+    auto FindIntervalInDisjointSet(const map<T, U>& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate)
     {
         auto inf = Infimum(ctr, query);
-        return inf == ctr.end() ? ctr.cend() : (isInIntervalPredicate(query, inf->second) ? inf : ctr.cend());
+        bool isInInterval = inf != ctr.cend() && isInIntervalPredicate(query, inf->second);
+        return isInInterval ? inf : ctr.cend();
     }
 
-    /// Log(N) disjointed segments belong query
-    /// segments are represented by their start points in the key, and last point in values. segments can't overlap.
-    /// returns: iterator to found interval key, or cend()
+    //! Log(N) disjointed segments belong query
+    //! segments are represented by their start points in the key, and last point in values. segments can't overlap.
+    //! returns: iterator to found interval key, or cend()
     template< typename T, typename U>
-    auto FindInterval(const map<T, U>& ctr, const T& query)
+    auto FindIntervalInDisjointSet(const map<T, U>& ctr, const T& query)
     {
-        return FindInterval(ctr, query, [](T q, U last) {return q <= last; });
+        return FindIntervalInDisjointSet(ctr, query, [](T q, U last) {return q <= last; });
     }
 
+    template< typename T >
+    struct Interval
+    {
+        bool IsEmpty() const { return b < a; }
+        bool operator== (Interval const& rhs) const { return a == rhs.a && b == rhs.b; }
+        bool operator< (Interval const& rhs) const { return a < rhs.a || (a == rhs.a && b < rhs.b); }
+        T a = (T)0;
+        T b = (T)-1;
+    };
+
+    //! In case of potential overlaps (not disjointed), this structure can support "is in" queries
+    template< typename T >
+    struct IntervalCollection
+    {
+        using Interval = Interval<T>;
+
+        //! Construction from an iterable collection of Interval typed elements
+        template< typename Iterator >
+        IntervalCollection(Iterator&& begin, Iterator&& end)
+            : m_obfirsts(begin, end), m_oblasts(begin, end)
+        {
+            std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2)
+                      {
+                          return i1.a < i2.a;
+                      });
+            std::sort(m_oblasts.begin(), m_oblasts.end(), [](auto i1, auto i2)
+                      {
+                          return i1.b < i2.b;
+                      });
+        }
+
+        //! Retrieve the subset of intervals activated by a point (query)
+        set<Interval> GetIntervalsSurrounding(T query)
+        {
+            // construct the set of intervals starting before:
+            set<Interval> startBefore;
+            CopyIf(m_obfirsts.begin(), m_obfirsts.end(),
+                   [=](auto interv) { return interv.a <= query; },
+                   std::inserter(startBefore, startBefore.end()),
+                   CopyIfPolicy::InterruptAtFirstFalse);  // because the obfirsts vector is sorted
+
+            // construct the set of intervals ending after:
+            set<Interval> endAfter;
+            CopyIf(m_oblasts.rbegin(), m_oblasts.rend(),  // reverse iteration
+                   [=](auto interv) { return interv.b >= query; },
+                   std::inserter(endAfter, endAfter.end()),
+                   CopyIfPolicy::InterruptAtFirstFalse);
+
+            set<Interval> result;
+            std::set_intersection(startBefore.begin(), startBefore.end(),
+                                  endAfter.begin(), endAfter.end(),
+                                  std::inserter(result, result.end()));
+            return result;
+        }
+
+        //! Get the interval surrounding query that has the closest start point to query.
+        //! In case of an interval collection representing a tree, that is,
+        //! each overlapping interval is fully contained in the bigger one,
+        //! the closest start is guaranteed to be the most "leaf" interval.
+        //! This is useful for scopes.
+        Interval GetClosestIntervalSurrounding(T query)
+        {
+            auto bag = GetIntervalsSurrounding(query);
+            return bag.empty() ? Interval{-1, -2} : *bag.rbegin();
+        }
+
+        vector<Interval> m_obfirsts;  // ordered by "firsts"
+        vector<Interval> m_oblasts;   // ordered by "lasts"
+    };
+
     template< typename Deduced >
     decltype(auto) CastToRValueReference(Deduced&& value)
     {
         return static_cast<std::remove_reference_t<Deduced>&&>(value);
     }
 
-    /// add a missing operator for convenience and shortness of code
+    //! add a missing operator for convenience and shortness of code
     inline bool operator == (string_view lhs, char rhs)
     {
         return lhs.length() == 1 && lhs[0] == rhs;
@@ -713,10 +806,10 @@ namespace AZ::Tests
             assert(yellow == intervals.cend());
             auto larger_than_all = Infimum(intervals, 15);
             assert(larger_than_all->first == 8);
-            assert(FindInterval(intervals, 4) != intervals.cend());
-            assert(FindInterval(intervals, 6) == intervals.cend());
-            assert(FindInterval(intervals, 8) != intervals.cend());
-            assert(FindInterval(intervals, 1) == intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 4) != intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 6) == intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 8) != intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 1) == intervals.cend());
 
             auto high = Infimum(intervals, 20);
             assert(high->first == 8);
@@ -734,6 +827,21 @@ namespace AZ::Tests
 
         assert(IsIn("hibou", std::initializer_list<const char*>{ "chouette", "hibou", "jay" }));
         assert(!IsIn("hibou", std::initializer_list<const char*>{ "chouette", "jay" }));
+
+        Interval<int> intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}};
+        IntervalCollection<int> ic{std::begin(intvs), std::end(intvs)};
+        assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty());
+        assert((ic.GetClosestIntervalSurrounding(0) == Interval<int>{0,10}));
+        assert((ic.GetClosestIntervalSurrounding(1) == Interval<int>{1,5}));
+        assert((ic.GetClosestIntervalSurrounding(3) == Interval<int>{3,3}));
+        assert((ic.GetClosestIntervalSurrounding(4) == Interval<int>{1,5}));
+        assert((ic.GetClosestIntervalSurrounding(6) == Interval<int>{0,10}));
+        assert((ic.GetClosestIntervalSurrounding(5) == Interval<int>{1,5}));
+        assert((ic.GetClosestIntervalSurrounding(7) == Interval<int>{7,9}));
+        assert((ic.GetClosestIntervalSurrounding(9) == Interval<int>{7,9}));
+        assert(ic.GetClosestIntervalSurrounding(11).IsEmpty());
+        assert((ic.GetClosestIntervalSurrounding(13) == Interval<int>{12,15}));
+        assert(ic.GetClosestIntervalSurrounding(16).IsEmpty());
     }
 }
 #endif

+ 110 - 6
tests/Semantic/typeof-keyword.azsl

@@ -19,7 +19,7 @@ top gettop();
 
 class A
 {
-	int a;
+    int a;
 };
 
 class B : A
@@ -83,16 +83,14 @@ void h()
     //   NumericConstructorExpression  float2(0,0)
     //   LiteralExpression             42
     //   CommaExpression               X, Y
-    // not supported:
     //   PostfixUnaryExpression        i++
     //   PrefixUnaryExpression         ++i
     //   BinaryExpression              i + j
-    // e.g. typeof(1 + 3) = <fail>
 
-     // mathematics
+    // mathematics
     __azslc_print_message("@check predicate ");
     __azslc_print_symbol(typeof(1 + 3), __azslc_prtsym_least_qualified);
-    __azslc_print_message(" == '<fail>'\n");
+    __azslc_print_message(" == 'int'\n");
 
     // literals
     __azslc_print_message("@check predicate ");
@@ -186,7 +184,7 @@ void h()
     __azslc_print_symbol(typeof(top::inner), __azslc_prtsym_fully_qualified);
     __azslc_print_message(" == '/top/inner'\n");
 
-	// class inheritance: parent member access using scope resolution operator
+    // class inheritance: parent member access using scope resolution operator
     __azslc_print_message("@check predicate ");
     __azslc_print_symbol(typeof(B::a), __azslc_prtsym_least_qualified);
     __azslc_print_message(" == 'int'\n");
@@ -445,4 +443,110 @@ void h()
     __azslc_print_message("@check predicate ");
     __azslc_print_symbol(typeof(INTVAR, INTVAR, DOUBLEVAR), __azslc_prtsym_least_qualified);
     __azslc_print_message(" == 'double'\n");
+
+    // binary arithmetic expression
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1 + 1), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'int'\n");
+
+    // with float promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1 + 1.f), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float'\n");
+
+    // with half promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1 + 1.h), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'half'\n");
+
+    // half and float to float promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1.f + 1.h), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float'\n");
+
+    // int16_t to double promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1.l + int16_t(1)), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double'\n");
+
+    // bool binary
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1.l && int16_t(1)), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'bool'\n");
+
+    // lookedup
+    double d;
+    int64_t i64;
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(d || i64), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'bool'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(d - i64), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double'\n");
+
+    // vector scalar
+    float4 f4;
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(2.f * f4), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float4'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(f4 * 2.f), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float4'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(d * f4), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double4'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(f4 * d), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double4'\n");
+
+    // matrix scalar
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float() * float3x2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float3x2'\n");
+
+    // matrix scalar with base type promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(half3x2() * double()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double3x2'\n");
+
+    // truncations
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float3() * float2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float2'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float4x4() * float2x2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float2x2'\n");
+
+    // truncate & upcast
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float4x4() * double2x2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double2x2'\n");
+
+    // through alias
+    typealias d34m = double3x4;
+    typealias real = half;
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof((d34m)0 * (real)0), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double3x4'\n");
+    // note: the parser takes d32m() as a function call, this is the "most verxing parse" problem
+    //       float() is understood by the parser as a NumericConstructorExpression because it
+    //       has a list of the tokens representing all fundamental types.
+    //       but, with user defined identifiers, it can't branch into the "intended" context,
+    //       because ALL(*) Antlr4 parsers recognizes context free grammar only.
+    //       We could adopt a universal construction syntax Obj{} a la C++ for AZSL but it's not
+    //       compliant with the philosophy of not deviating from HLSL, which on its side doesn't really
+    //       accept constructor-type constructs. DXC just does it better here because clang parser is Turing complete.
+
+    d34m d34var;
+    real rVar;
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(rVar - d34var ), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double3x4'\n");
 }