Explorar o código

Merge pull request #84 from o3de/auto-option-ranks

Auto option ranks
siliconvoodoo %!s(int64=2) %!d(string=hai) anos
pai
achega
84e910ff19

+ 1 - 0
src/AzslcBackend.cpp

@@ -449,6 +449,7 @@ namespace AZ::ShaderCompiler
             // We reserve the right to change it in the future so we make it explicit attribute here
             shaderOption["order"] = optionOrder;
             optionOrder++;
+            shaderOption["costImpact"] = varInfo->m_estimatedCostImpact;
 
             bool isUdt = IsUserDefined(varInfo->GetTypeClass());
             assert(isUdt || IsPredefinedType(varInfo->GetTypeClass()));

+ 0 - 1
src/AzslcEmitter.cpp

@@ -1277,7 +1277,6 @@ namespace AZ::ShaderCompiler
         const ICodeEmissionMutator* codeMutator = m_codeMutator;
 
         ssize_t ii = interval.a;
-        bool wasInPreprocessorDirective = false;  // record a state to detect exit of directives, because they need to reside on their own lines
         while (ii <= interval.b)
         {
             auto* token = GetNextToken(ii /*inout*/);

+ 1 - 3
src/AzslcIntermediateRepresentation.cpp

@@ -332,7 +332,6 @@ namespace AZ::ShaderCompiler
                 cout << "  storage: " << sub.m_typeInfoExt.m_qualifiers.GetDisplayName() << "\n";
                 cout << "  array dim: \"" << sub.m_typeInfoExt.m_arrayDims.ToString() << "\"\n";
                 cout << "  has sampler state: " << (sub.m_samplerState ? "yes\n" : "no\n");
-                cout << "\n";
                 if (!holds_alternative<monostate>(sub.m_constVal))
                 {
                     cout << "  val: " << ExtractValueAsInt64(sub.m_constVal) << "\n";
@@ -519,7 +518,7 @@ namespace AZ::ShaderCompiler
             if (varInfo.GetTypeClass() == TypeClass::Enum)
             {
                 auto* asClassInfo = GetSymbolSubAs<ClassInfo>(varInfo.GetTypeId().GetName());
-                size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.GetBaseSize();
+                size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.m_baseSize;
             }
 
             nextMemberStartingOffset = Packing::PackNextChunk(layoutPacking, size, startAt);
@@ -960,5 +959,4 @@ namespace AZ::ShaderCompiler
         }
         return memberList[memberList.size() - 1];
     }
-
 }  // end of namespace AZ::ShaderCompiler

+ 0 - 1
src/AzslcIntermediateRepresentation.h

@@ -14,7 +14,6 @@
 
 namespace AZ::ShaderCompiler
 {
-
     //! We limit the maximum number of render targets to 8, with indices in the range [0..7]
     static const uint32_t kMaxRenderTargets = 8;
 

+ 3 - 1
src/AzslcKindInfo.h

@@ -248,7 +248,7 @@ namespace AZ::ShaderCompiler
         //! Get the size of a single element, ignoring array dimensions
         const uint32_t GetSingleElementSize(Packing::Layout layout, bool defaultRowMajor) const
         {
-            auto baseSize = m_coreType.m_arithmeticInfo.GetBaseSize();
+            auto baseSize = m_coreType.m_arithmeticInfo.m_baseSize;
             bool isRowMajor = (m_mtxMajor == Packing::MatrixMajor::RowMajor ||
                               (m_mtxMajor == Packing::MatrixMajor::Default && defaultRowMajor));
             auto rows = m_coreType.m_arithmeticInfo.m_rows;
@@ -399,6 +399,7 @@ namespace AZ::ShaderCompiler
         ConstNumericVal            m_constVal;   // (attempted folded) initializer value for simple scalars
         optional<SamplerStateDesc> m_samplerState;
         ExtendedTypeInfo           m_typeInfoExt;
+        int                        m_estimatedCostImpact = -1;  //!< Cached value calculated by AnalyzeOptionRanks
     };
 
     // VarInfo methods definitions
@@ -791,6 +792,7 @@ namespace AZ::ShaderCompiler
         vector< IdentifierUID >   m_overrides;                //!< list of implementing functions in child classes
         optional< IdentifierUID > m_base;   //!< points to the overridden function in the base interface, if applies. only supports one base
         FunctionMultiForwards     m_multiFwds    = FMF_None;  //!< presence of redundant prototype-only declarations
+        int                       m_costScore    = -1;        //!< heuristical static analysis of the amount of instructions contained
         struct Parameter
         {
             IdentifierUID m_varId;

+ 1 - 2
src/AzslcMain.cpp

@@ -23,8 +23,7 @@ namespace StdFs = std::filesystem;
 // For large features or milestones. Minor version allows for breaking changes. Existing tests can change.
 #define AZSLC_MINOR "8"   // last change: introduction of class inheritance
 // For small features or bug fixes. They cannot introduce breaking changes. Existing tests shouldn't change.
-#define AZSLC_REVISION "17"  // last change: fixup alignment check logic_error because of lack of an inter-scope check limiter.
-                    // "16"          change: fixup runtime error with redundant function declarations
+#define AZSLC_REVISION "18"  // last change: automatic option ranks
 
 namespace AZ::ShaderCompiler
 {

+ 223 - 19
src/AzslcReflection.cpp

@@ -589,7 +589,7 @@ namespace AZ::ShaderCompiler
             else if (varInfo.GetTypeClass() == TypeClass::Enum)
             {
                 auto* asClassInfo = m_ir->GetSymbolSubAs<ClassInfo>(varInfo.GetTypeId().GetName());
-                size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.GetBaseSize();
+                size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.m_baseSize;
             }
 
             offset = Packing::PackNextChunk(layoutPacking, size, startAt);
@@ -629,7 +629,9 @@ namespace AZ::ShaderCompiler
 
     void CodeReflection::DumpVariantList(const Options& options) const
     {
+        AnalyzeOptionRanks();
         m_out << GetVariantList(options);
+        m_out << "\n";
     }
 
     static void ReflectBinding(Json::Value& output, const RootSigDesc::SrgParamDesc& bindInfo)
@@ -857,11 +859,12 @@ namespace AZ::ShaderCompiler
         for (auto& seenat : kindInfo->GetSeenats())
         {
             assert(uid == seenat.m_referredDefinition);
-            // TODO: the assumption that intervals where distinct doesnt hold anymore now that we have unnamed scopes
-            auto intervalIter = FindInterval(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
-                                             {
-                                                 return value.first.properlyContains({key, key});
-                                             });
+            // careful of the invariant: distinct intervals. (can't support functions nested in functions nor imbricated block scopes)
+            // ok for now because AZSL/HLSL don't have lambdas
+            auto intervalIter = FindIntervalInDisjointSet(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
+                                                          {
+                                                              return value.first.properlyContains({key, key});
+                                                          });
             if (intervalIter != scopes.cend())
             {
                 const IdentifierUID& encloser = intervalIter->second.second;
@@ -909,16 +912,9 @@ namespace AZ::ShaderCompiler
         uint32_t numOf32bitConst  = GetNumberOf32BitConstants(options, m_ir->m_rootConstantStructUID);
         RootSigDesc rootSignature = BuildSignatureDescription(options, numOf32bitConst);
 
-        // prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
-        MapOfBeginToSpanAndUid scopeStartToFunctionIntervals;
-        for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals)
-        {
-            if (m_ir->GetKind(uid) == Kind::Function)  // Filter out unnamed blocs and types. We need a set of disjoint intervals as an invariant for the next algorithm.
-            {
-                // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
-                scopeStartToFunctionIntervals[interval.a] = std::make_pair(interval, uid);
-            }
-        }
+        // Prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
+        // (truth: we need a set of disjoint intervals as an invariant for the following algorithm)
+        GenerateTokenScopeIntervalToUidReverseMap();
 
         Json::Value srgRoot(Json::objectValue);
         // Order the reflection by SRG for convenience
@@ -968,7 +964,7 @@ namespace AZ::ShaderCompiler
                     else
                     {
                         set<IdentifierUID> dependencyList;
-                        DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, scopeStartToFunctionIntervals);
+                        DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, m_functionIntervals);
                         srgMember[srgParam.m_uid.GetNameLeaf()] = makeJsonNodeForOneResource(dependencyList, srgParam, {});
                     }
                 }
@@ -981,7 +977,7 @@ namespace AZ::ShaderCompiler
                     for (auto& srgConstant : srgInfo->m_implicitStruct.GetMemberFields())
                     {
                         allConstants.append({ srgConstant.GetNameLeaf() });
-                        DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, scopeStartToFunctionIntervals);
+                        DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, m_functionIntervals);
                     }
                     // variant fallback support
                     if (srgInfo->m_shaderVariantFallback)
@@ -992,7 +988,7 @@ namespace AZ::ShaderCompiler
                         {
                             if (varSub->CheckHasStorageFlag(StorageFlag::Option))
                             {
-                                DiscoverTopLevelFunctionDependencies(varUid, dependencyList, scopeStartToFunctionIntervals);
+                                DiscoverTopLevelFunctionDependencies(varUid, dependencyList, m_functionIntervals);
                             }
                         }
                     }
@@ -1004,4 +1000,212 @@ namespace AZ::ShaderCompiler
 
         m_out << srgRoot;
     }
+
+    // Helper routine for option rank analysis
+    static int GuesstimateIntrinsicFunctionCost(string_view funcName)
+    {
+        if (IsOneOf(funcName, "CallShader", "TraceRay"))
+        { // non measurable but assumed high
+            return 100;
+        }
+        else if (IsOneOf(funcName, "Sample", "Load", "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append"))
+        { // memory access, locked or not, will have high latency
+            return 10;
+        }
+        else
+        { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1.
+            return 1;
+        }
+    }
+
+    // Helper routine for option rank analysis. When picking AN overload is more useful than forfeiting.
+    // The function GetConcreteFunctionThatMatchesArgumentList forfeits when the overloadset contains
+    // strictly more than 1 concrete function with the queried arity. In our case, we prefer to just pick any.
+    static IdentifierUID PickAnyOverloadThatMatchesArgCount(IntermediateRepresentation* ir,
+                                                            azslParser::FunctionCallExpressionContext* callNode,
+                                                            KindInfo& overload)
+    {
+        IdentifierUID concrete;
+        size_t numArgs = NumArgs(callNode);
+        overload.GetSubAs<OverloadSetInfo>()->AnyOf(
+            [&](IdentifierUID const& uid)
+            {
+                auto* concreteFcInfo = ir->GetSymbolSubAs<FunctionInfo>(uid.GetName());
+                size_t numParams = concreteFcInfo->GetParameters(true).size();
+                if (numParams == numArgs)
+                {
+                    concrete = uid;  // we write the result through reference capture (not clean but convenient)
+                    return true;
+                }
+                return false;
+            });
+        return concrete;
+    }
+
+    void CodeReflection::AnalyzeOptionRanks() const
+    {
+        // make sure we have the scope lookup cache ready
+        GenerateTokenScopeIntervalToUidReverseMap();
+        // loop over variables
+        for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3<VarInfo>())
+        {
+            // only options
+            if (varInfo->CheckHasStorageFlag(StorageFlag::Option))
+            {
+                int impactScore = 0;
+                // loop over appearances over the program
+                for (Seenat& ref : kindInfo->GetSeenats())
+                {
+                    // determine an impact score
+                    impactScore += AnalyzeImpact(ref.m_where)  // dependent code that may be skipped depending on the value of that ref
+                        + 1;  // by virtue of being mentioned (seenat), we count the reference as an access of cost 1.
+                }
+                varInfo->m_estimatedCostImpact = impactScore;
+            }
+        }
+    }
+
+    int CodeReflection::AnalyzeImpact(TokensLocation const& location) const
+    {
+        // find the node at `location`:
+        ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId);
+        // go up tree to meet a block node that has visitable depth:
+        // can be any of if/for/while/switch
+        //  4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binaryop->braces->cmpexpr->cond->for
+        if (auto* whileNode = DeepParentAs<azslParser::WhileStatementContext*>(node->parent, 3))
+        {
+            node = whileNode->embeddedStatement();
+        }
+        else if (auto* ifNode = DeepParentAs<azslParser::IfStatementContext*>(node->parent, 3))
+        {
+            node = ifNode->embeddedStatement();
+        }
+        else if (auto* forNode = DeepParentAs<azslParser::ForStatementContext*>(node->parent, 4))
+        {
+            node = forNode->embeddedStatement();
+        }
+        else if (auto* switchNode = DeepParentAs<azslParser::SwitchStatementContext*>(node->parent, 3))
+        {
+            node = switchNode->switchBlock();
+        }
+        int score = 0;
+        AnalyzeImpact(node, score);
+        return score;
+    }
+
+    void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const
+    {
+        for (auto& c : astNode->children)
+        {
+            if (auto* callNode = As<azslParser::FunctionCallExpressionContext*>(c))
+            {
+                // branch into an overload specialized for function lookup:
+                AnalyzeImpact(callNode, scoreAccumulator);
+            }
+            else if (auto* node = As<ParserRuleContext*>(c))
+            {
+                AnalyzeImpact(node, scoreAccumulator); // recurse down to make sure to capture embedded calls, like e.g. "x ? f() : 0;"
+            }
+            if (auto* leaf = As<tree::TerminalNode*>(c))
+            {
+                // determine cost by number of full expressions separated by semicolon
+                scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi;  // bool as 0 or 1 trick
+            }
+        }
+    }
+
+    void CodeReflection::AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const
+    {
+        // to access the function symbol info we need the current scope, the function call name and perform a lookup.
+
+        // figure out the scope at this token.
+        // theoretically should be something in the like of the body of another function,
+        // or an anonymous block within another function.
+        auto interval = m_intervals.GetClosestIntervalSurrounding(callNode->start->getTokenIndex());
+        if (!interval.IsEmpty())
+        {
+            IdentifierUID encloser = m_intervalToUid[interval];
+
+            // Because we are past the end of the semantic analysis,
+            // the scope tracker is registering the last seen scope (surely "/").
+            // This is a stateful side-effect system unfortunately, and since we'll call
+            // some feature of the semantic orchestrator (like TypeofExpr) we need to hack
+            // the scope tracker:
+            m_ir->m_sema.m_scope->m_currentScopePath = encloser.GetName();
+            m_ir->m_sema.m_scope->UpdateCurScopeUID();
+
+            QualifiedName startupLookupScope = encloser.GetName();
+            UnqualifiedName funcName;
+            if (auto* idExpr = As<azslParser::IdentifierExpressionContext*>(callNode->Expr))
+            {
+                funcName = ExtractNameFromIdExpression(idExpr->idExpression());
+            }
+            else if (auto* maeExpr = As<AstMemberAccess*>(callNode->Expr))
+            {
+                startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr);
+                funcName = ExtractNameFromIdExpression(maeExpr->Member);
+            }
+            IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName);
+            if (!overload) // in case of function not found, we assume it's an intrinsic.
+            {
+                scoreAccumulator += GuesstimateIntrinsicFunctionCost(funcName);
+            }
+            else
+            {
+                azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode);
+                IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args);
+                IdentifierUID concrete;
+                if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet)
+                {   // in case of strict selection failure, run a fuzzy select
+                    concrete = PickAnyOverloadThatMatchesArgCount(m_ir, callNode, overload->second);
+                    // if still not enough to get a fix (concrete=={}), it might be an ill-formed input. prefer to forfeit
+                }
+                else
+                {
+                    concrete = symbolMeantUnderCallNode->first;
+                }
+
+                if (auto* funcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(concrete.GetName()))
+                {
+                    if (funcInfo->m_costScore == -1)  // cost not yet discovered for this function
+                    {
+                        funcInfo->m_costScore = 0;
+                        using AstFDef = azslParser::HlslFunctionDefinitionContext;
+                        AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
+                                      funcInfo->m_costScore);  // recurse and cache
+                    }
+                    scoreAccumulator += funcInfo->m_costScore;
+                }
+            }
+            // other cases forfeited for now, but that would at least include things like eg braces (f)()
+        }
+        else // no interval found
+        {
+            // function calls outside of function bodies can appear in an initializer:
+            //    int g_a = MakeA();  // global init
+            //    class C { int m_a = CompA();  // constructor init (invalid AZSL/HLSL)
+            //    class D { void Method(int a_a = DefaultA());  // default parameter value
+            // in any case, extracting the scope is impossible with this system.
+            // we forfeit evaluation of a score
+        }
+    }
+
+    void CodeReflection::GenerateTokenScopeIntervalToUidReverseMap() const
+    {
+        if (m_functionIntervals.empty())
+        {
+            for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals)
+            {
+                if (m_ir->GetKind(uid) == Kind::Function)  // Filter out unnamed blocs and types.
+                {
+                    // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
+                    m_functionIntervals[interval.a] = std::make_pair(interval, uid);
+                }
+                auto i = Interval<ssize_t>{interval.a, interval.b};
+                m_intervals.Add(i);
+                m_intervalToUid[i] = uid;
+            }
+            m_intervals.Seal();
+        }
+    }
 }

+ 21 - 1
src/AzslcReflection.h

@@ -11,6 +11,9 @@
 
 namespace AZ::ShaderCompiler
 {
+    using MapOfBeginToSpanAndUid = map<ssize_t, pair< misc::Interval, IdentifierUID> >;
+    using MapOfIntervalToUid = map<Interval<ssize_t>, IdentifierUID>;
+
     struct CodeReflection : Backend
     {
         CodeReflection(IntermediateRepresentation* ir, TokenStream* tokens, std::ostream& out)
@@ -45,6 +48,9 @@ namespace AZ::ShaderCompiler
         //! @param options  user configuration parsed from command line
         void DumpResourceBindingDependencies(const Options& options) const;
 
+        //! Determine a heurisitcal global order between options in the program, using "impacted code size" static analysis.
+        void AnalyzeOptionRanks() const;
+
     private:
 
         //! Builds member variable packing information and adds it to the membersContainer
@@ -63,7 +69,6 @@ namespace AZ::ShaderCompiler
 
         bool BuildOMStruct(const ExtendedTypeInfo& returnTypeRef, string_view semanticOverride, Json::Value& jsonVal, int& semanticIndex) const;
 
-        using MapOfBeginToSpanAndUid = map<ssize_t, pair< misc::Interval, IdentifierUID> >;
         //! Populate a list of functions where a symbol appear as potentially used
         //! @param uid      The symbol to start the dependency analysis on
         //! @param output   Any dependency symbol will be appended to this set
@@ -75,6 +80,21 @@ namespace AZ::ShaderCompiler
 
         bool IsPotentialEntryPoint(const IdentifierUID& uid) const;
 
+        // Estimate a score proportional to how much code is "child" to the AST node at `location`
+        int AnalyzeImpact(TokensLocation const& location) const;
+
+        // Recursive internal detail version
+        void AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const;
+
+        // Function-call specific
+        void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const;
+
+        //! Useful for static analysis on dependencies or option ranks
+        void GenerateTokenScopeIntervalToUidReverseMap() const;
+        mutable MapOfBeginToSpanAndUid m_functionIntervals;  //< only functions because they are guaranteed to be disjointed (largely simplifies queries)
+        mutable IntervalCollection<ssize_t> m_intervals;  //< augmented version with anonymous blocks (slower query)
+        mutable MapOfIntervalToUid m_intervalToUid;  //< side by side data since we don't want to weight the interval struct with a payload
+
         std::ostream& m_out;
     };
 }

+ 53 - 14
src/AzslcSemanticOrchestrator.cpp

@@ -1166,7 +1166,10 @@ namespace AZ::ShaderCompiler
                                      As<azslParser::AssignmentExpressionContext*>(ctx),
                                      As<azslParser::NumericConstructorExpressionContext*>(ctx),
                                      As<azslParser::LiteralExpressionContext*>(ctx),
-                                     As<azslParser::LiteralContext*>(ctx));
+                                     As<azslParser::LiteralContext*>(ctx),
+                                     As<azslParser::PrefixUnaryExpressionContext*>(ctx),
+                                     As<azslParser::PostfixUnaryExpressionContext*>(ctx),
+                                     As<azslParser::BinaryExpressionContext*>(ctx));
         }
         catch (AllNull&)
         {
@@ -1175,6 +1178,48 @@ namespace AZ::ShaderCompiler
         }
     }
 
+    QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const
+    {
+        // among all possibilities Plus|Minus|Not|Tilde|PlusPlus|MinusMinus
+        // only "Not" returns a bool, the rest is transparent and returns the same type as rhs
+        return ctx->prefixUnaryOperator()->start->getType() == azslLexer::Not ? MangleScalarType("bool")
+            : TypeofExpr(ctx->Expr);
+    }
+
+    QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const
+    {
+        return TypeofExpr(ctx->Expr); // in case of x++ or x-- the type is the type of x.
+    }
+
+    QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::BinaryExpressionContext* ctx) const
+    {
+        using lex = azslLexer;
+        auto boolResultOperators = {lex::Less, lex::Greater, lex::LessEqual, lex::GreaterEqual, lex::NotEqual, lex::AndAnd, lex::OrOr};
+        if (IsIn(ctx->binaryOperator()->start->getType(), boolResultOperators))
+        {
+            return MangleScalarType("bool");
+        }
+        QualifiedName lhs = TypeofExpr(ctx->Left);
+        QualifiedName rhs = TypeofExpr(ctx->Right);
+        TypeRefInfo typeInfoLhs = CreateTypeRefInfo(UnqualifiedNameView{lhs});  // We tolerate a cast here because GetTypeRefInfo was designed to lookup types, but TypeofExpr has already looked up the type.
+        TypeRefInfo typeInfoRhs = CreateTypeRefInfo(UnqualifiedNameView{rhs});
+        if (typeInfoLhs.m_arithmeticInfo.IsEmpty() || typeInfoRhs.m_arithmeticInfo.IsEmpty())
+        {   // Case that shouldn't work in AZSL yet (but may work in HLSL2021)
+            //  -> UDT operator (would need support of operator overloading).
+            // We arbitrarily assume a result "type of left expression".
+            // (what we would really need to do is go get the return type of the overloaded operator)
+            return lhs;
+        } // After this `if`, both sides are arithmetic types (scalar, vector, matrix).
+        // "matrix op vector" (or commutated) is a forbidden case,
+        //   e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar)
+        // It will do piecewise `op` and return more or less the same type as the operands.
+        // In case of dimension differences, it will truncate to the smaller type
+        //   e.g float2 + float3 results in float2 with .z lost in implicit cast
+        //      same for float2x3 * float2x2 (results in float2x2)
+        // We assume that for all non bool ops: * + - / % ^ | & << >>
+        return PromoteTruncateWith({typeInfoLhs.m_arithmeticInfo, typeInfoRhs.m_arithmeticInfo}).GenTypeId();
+    }
+
     QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const
     {
         return VisitFirstNonNull([this](auto* ctx) { return TypeofExpr(ctx); },
@@ -1303,29 +1348,23 @@ namespace AZ::ShaderCompiler
 
     QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::LiteralContext* ctx) const
     {
-        // verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array.
-        auto checkExistType = [](string_view scalarName){return std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), scalarName) != AZ::ShaderCompiler::Predefined::Scalar.end();};
         // verifies that last or 1-before-last characters are a particular literal suffix. like in "56ul"
         auto hasSuffix = [](auto node, char s){return tolower(node->getText().back()) == s || tolower(Slice(node->getText(), -3, -2) == s);};
-        auto checkAndReturn = [&](string_view typeName)
-                              {
-                                  assert(checkExistType(typeName));
-                                  return QualifiedName{"?" + string{typeName}};
-                              };
+
         if (ctx->True() || ctx->False())
         {
-            return checkAndReturn("bool");
+            return MangleScalarType("bool");
         }
         else if (auto* literal = ctx->IntegerLiteral())
         {
-            return hasSuffix(literal, 'u') ? checkAndReturn("uint")
-                 : checkAndReturn("int");
+            return hasSuffix(literal, 'u') ? MangleScalarType("uint")
+                 : MangleScalarType("int");
         }
         else if (auto* literal = ctx->FloatLiteral())
         {
-            return hasSuffix(literal, 'h') ? checkAndReturn("half")
-                 : hasSuffix(literal, 'l') ? checkAndReturn("double")
-                 : checkAndReturn("float");
+            return hasSuffix(literal, 'h') ? MangleScalarType("half")
+                 : hasSuffix(literal, 'l') ? MangleScalarType("double")
+                 : MangleScalarType("float");
         }
         return {"<fail>"};
     }

+ 4 - 1
src/AzslcSemanticOrchestrator.h

@@ -222,6 +222,9 @@ namespace AZ::ShaderCompiler
         auto TypeofExpr(azslParser::LiteralExpressionContext* ctx) const -> QualifiedName;
         auto TypeofExpr(azslParser::LiteralContext* ctx) const -> QualifiedName;
         auto TypeofExpr(azslParser::CommaExpressionContext* ctx) const -> QualifiedName;
+        auto TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const -> QualifiedName;
+        auto TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const -> QualifiedName;
+        auto TypeofExpr(azslParser::BinaryExpressionContext* ctx) const -> QualifiedName;
 
         //! Parse the AST from a variable declaration and attempt to extract array dimensions integer constants [dim1][dim2]...
         //! Return: <true> on success, <false> otherwise
@@ -330,7 +333,7 @@ namespace AZ::ShaderCompiler
         {
             auto typeId      = LookupType(typeNameOrCtx, policy);
             auto tClass      = GetTypeClass(typeId);
-            auto arithmetic  = IsNonGenericArithmetic(tClass) ? CreateArithmeticTypeInfo(typeId.GetName()) : ArithmeticTypeInfo{}; // TODO: canonicalize generics
+            auto arithmetic  = IsNonGenericArithmetic(tClass) ? CreateArithmeticTraits(typeId.GetName()) : ArithmeticTraits{}; // TODO: canonicalize generics
             return TypeRefInfo{typeId, arithmetic, tClass};
         }
 

+ 2 - 2
src/AzslcSymbolAggregator.cpp

@@ -20,7 +20,7 @@ namespace AZ::ShaderCompiler
             auto& [uid, kindInfo] = st.AddIdentifier(azirName, Kind::Type);  // the kind is Type because all predefined are stored as such.
             auto& typeInfo        = kindInfo.GetSubAfterInitAs<Kind::Type>();
             auto typeClass        = TypeClass::FromStr(bag.m_name);
-            auto arithmetic       = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTypeInfo(azirName) : ArithmeticTypeInfo{}; // TODO: canonicalize generics
+            auto arithmetic       = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTraits(azirName) : ArithmeticTraits{}; // TODO: canonicalize generics
             typeInfo = TypeRefInfo{uid, arithmetic, typeClass};
         }
     }
@@ -140,7 +140,7 @@ namespace AZ::ShaderCompiler
         {
             auto attempt = QualifiedName{JoinPath(path, name)};
             got = GetIdAndKindInfo(attempt);
-            exit = path == "/";
+            exit = path == "/" || path.empty();
             if (!got)
             {
                 if (auto* scopeAsClass = GetAsSub<ClassInfo>(IdentifierUID{GetParentName(attempt)})) // get enclosing class

+ 131 - 30
src/AzslcTypes.h

@@ -260,36 +260,84 @@ namespace AZ::ShaderCompiler
         return result;
     }
 
-    /// Rows and Cols (this is specific to shader languages to identify vector and matrix types)
-    struct ArithmeticTypeInfo
+    /// Verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array.
+    inline bool CheckExistScalarType(string_view scalarName)
     {
-        void ResolveSize()
-        {
-            m_size = Packing::PackedSizeof(m_underlyingScalar);
-        }
+        return std::find(Predefined::Scalar.begin(),
+                         Predefined::Scalar.end(),
+                         scalarName) != Predefined::Scalar.end();
+    };
+
+    /// Assert validity of the type string, and form a "?<type>" tainted string to host a scalar type
+    inline QualifiedName MangleScalarType(string_view typeName)
+    {
+        assert(CheckExistScalarType(typeName));
+        return QualifiedName{"?" + string{typeName}};
+    };
 
-        /// Get the size of a single base element
-        const uint32_t GetBaseSize() const
+    //! Holds arithmetic-type-class information in small pieces (row, cols, base, size, rank...)
+    struct ArithmeticTraits
+    {
+        void ResolveBaseSizeAndRank()
         {
-            return m_size;
+            m_baseSize = Packing::PackedSizeof(m_underlyingScalar);
+            // establish the conversion rank:
+            auto getIndex = [](string_view s) -> int
+            {
+                auto const& Scalars = Predefined::Scalar;
+                return ::std::distance(Scalars.begin(),
+                                       ::std::find(Scalars.begin(), Scalars.end(), s));
+            };
+            // According to https://en.cppreference.com/w/cpp/language/usual_arithmetic_conversions
+            //   - No two signed have the same rank (even if same siezeof)
+            //   - rank of unsigned = rank of corresponding signed
+            //   - "standard" is > "extended" of same sizeof
+            //   - The rank of bool is the smallest
+            // That said, we will take inspiration from ASTContext::getIntegerRank of clang
+            static const unordered_map<int, int> subranks =
+            {
+                {getIndex("bool"), 1},
+                {getIndex("int16_t"), 2},
+                {getIndex("uint16_t"), 3}, // unsigned wins in case of subrank draw, according to arithmetic conversion rules
+                {getIndex("int"), 4},
+                {getIndex("uint"), 5},
+                {getIndex("dword"), 6},
+                {getIndex("int32_t"), 7},
+                {getIndex("uint32_t"), 8},
+                {getIndex("int64_t"), 9},
+                {getIndex("uint64_t"), 10},
+                {getIndex("half"), 11 << 5},  // floats win all conversions, even halfs
+                {getIndex("float"), 12 << 5},
+                {getIndex("double"), 13 << 5},
+            };
+            // `basesize` getter, but 1 for bool: (physical size of extern bool is considered 32bits in HLSL)
+            auto getRankSizeof = [&](int scalarId)
+            {
+                assert(string_view{"bool"} == Predefined::Scalar[0]);  // verify that 0 is the hard index of bool.
+                bool isBool = scalarId == 0;
+                return isBool ? 1 : m_baseSize;
+            };
+            // The shift method is taken from clang, I suppose it's a multi-parameter order cramed into bits.
+            // so because 10 is the largest subrank, shift by 4 should separate sizeof space and subrank space.
+            m_conversionRank = (getRankSizeof(m_underlyingScalar) << 4) + subranks.at(m_underlyingScalar);
         }
 
-                /// Get the size of a the element with regard to dimensions as well
-        const uint32_t GetTotalSize() const
+        //! Get the size of the whole type considering dimensions
+        uint32_t GetTotalSize() const
         {
-            return m_size * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1);
+            return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1);
         }
 
-        /// True if the type is a vector type. If it's a vector type it cannot be a matrix as well.
-        const bool IsVector() const
+        //! True if the type is a vector type. If it's a vector type it cannot be a matrix as well.
+        bool IsVector() const
         {
             // This treats special cases like 2x1, 3x1 and 4x1 as vectors
             // The behavior is consistent with dxc packing rules
             return (m_cols == 1 && m_rows > 1) || (m_cols > 1 && m_rows == 0);
         }
 
-        /// True if the type is a matrix type. If it's a matrix type it cannot be a vector as well.
-        const bool IsMatrix() const
+        //! True if the type is a matrix type. If it's a matrix type it cannot be a vector as well.
+        bool IsMatrix() const
         {
             // This fails special cases like 2x1, 3x1 and 4x1,
             //   but allows cases like 1x2, 1x3 and 1x4.
@@ -297,23 +345,46 @@ namespace AZ::ShaderCompiler
             return m_rows > 0 && m_cols > 1;
         }
 
-        /// If initialized as a fundamental -> not empty.
-        const bool IsEmpty() const
+        bool IsScalar() const
+        {
+            return m_rows <= 1 && m_cols <= 1;  // float, float1, float1x1 are scalars.
+        }
+
+        //! Non-created state
+        bool IsEmpty() const
         {
             return m_underlyingScalar == -1;
         }
 
-        // for pretty print
+        //! For pretty print
         string UnderlyingScalarToStr() const
         {
             return m_underlyingScalar >= 0 && m_underlyingScalar < AZ::ShaderCompiler::Predefined::Scalar.size() ?
                 AZ::ShaderCompiler::Predefined::Scalar[m_underlyingScalar] : "<NA>";
         }
 
-        uint32_t m_size = 0;                     // In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct
-        uint32_t m_rows = 0;                     // 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix
-        uint32_t m_cols = 0;                     // 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix
-        int      m_underlyingScalar = -1;        // index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf.
+        //! Create a canonicalized mangled name that should represent the identity of this arithmetic type.
+        QualifiedName GenTypeId() const
+        {
+            if (IsMatrix())
+            {
+                return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + ToString(m_rows) + "x" + ToString(m_cols)};
+            }
+            else if (IsVector())
+            {
+                return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + (m_rows > 0 ? ToString(m_rows) : ToString(m_cols))};
+            }
+            else
+            {
+                return QualifiedName{MangleScalarType(UnderlyingScalarToStr())};
+            }
+        }
+
+        uint32_t m_baseSize = 0;                 //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct
+        uint32_t m_rows = 0;                     //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix
+        uint32_t m_cols = 0;                     //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix
+        int      m_conversionRank = 0;           //< Used in conversions and promotions
+        int      m_underlyingScalar = -1;        //< Index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf.
     };
 
     //! TypeRefInfo holds resolved immutable information of a core type (the `matrix2x2` in `column_major matrix2x2 a[3];`)
@@ -323,7 +394,7 @@ namespace AZ::ShaderCompiler
     struct TypeRefInfo
     {
         TypeRefInfo() = default;
-        TypeRefInfo(IdentifierUID typeId, const ArithmeticTypeInfo& fundamentalInfo, TypeClass typeClass)
+        TypeRefInfo(IdentifierUID typeId, const ArithmeticTraits& fundamentalInfo, TypeClass typeClass)
             : m_arithmeticInfo{fundamentalInfo},
               m_typeClass{typeClass},
               m_typeId{typeId}
@@ -359,19 +430,19 @@ namespace AZ::ShaderCompiler
             return !operator==(lhs,rhs);
         }
 
-        IdentifierUID       m_typeId;
-        TypeClass           m_typeClass;
-        ArithmeticTypeInfo  m_arithmeticInfo;
+        IdentifierUID     m_typeId;
+        TypeClass         m_typeClass;
+        ArithmeticTraits  m_arithmeticInfo;
     };
 
     //! Run a syntactic analysis of an arithmetic type name and extract info on its composition
-    inline ArithmeticTypeInfo CreateArithmeticTypeInfo(QualifiedName a_typeName)
+    inline ArithmeticTraits CreateArithmeticTraits(QualifiedName a_typeName)
     {
         assert(IsArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) ));  // no need to call this function if you don't have a fundamental (non void)
         assert(!IsGenericArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) ));
         // ↑ fatal aspect. The input needs to be canonicalized earlier to minimize this function's complexity.
 
-        ArithmeticTypeInfo toReturn;
+        ArithmeticTraits toReturn;
 
         string typeName = UnMangle(a_typeName);
         size_t baseTypeLen = typeName.length();
@@ -422,10 +493,40 @@ namespace AZ::ShaderCompiler
         auto it = ::std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), baseType);
         assert(it != AZ::ShaderCompiler::Predefined::Scalar.end()); // baseType must exist in the Scalar bag by program invariant.
         toReturn.m_underlyingScalar = static_cast<int>( std::distance(AZ::ShaderCompiler::Predefined::Scalar.begin(), it) );
-        toReturn.ResolveSize();
+        toReturn.ResolveBaseSizeAndRank();
         return toReturn;
     }
 
+    //! Create a new arithmetic traits promoted by necessity (through a binary operation usually)
+    //! of compatibility with a second operand of arithmetic typeclass.
+    //! For example: type(half{} + int{})->half
+    //!              type(float3x3{} * double{})->double3x3
+    //! And with columns & rows truncated to the smallest operand,
+    //!   as part of the implicit necessary cast for operation compatibility.
+    inline ArithmeticTraits PromoteTruncateWith(Pair<ArithmeticTraits> operands)
+    {
+        auto [_1, _2] = operands;
+        // put the scalar last (will be useful later if there is only one scalar operand)
+        SwapIf(_1, _2, _1.IsScalar());
+        // now, let's construct the result in _1
+
+        if (_2.m_conversionRank > _1.m_conversionRank) // the higher ranking underlying wins; independently of full object size
+        {
+            _1.m_underlyingScalar = _2.m_underlyingScalar;
+        }
+        // cases: scalar-scalar : no dim change
+        //        vecmat-scalar : no dim change, since result is dim(vecmat) which is in _1
+        //        scalar-vecmat : impossible (sorted by swap above)
+        //        vecmat-vecmat : min(_1,_2)
+        if (!_1.IsScalar() && !_2.IsScalar())
+        {
+            _1.m_rows = std::min(_1.m_rows, _2.m_rows);
+            _1.m_cols = std::min(_1.m_cols, _2.m_cols);
+        }
+        _1.ResolveBaseSizeAndRank();
+        return _1;
+    }
+
     MAKE_REFLECTABLE_ENUM(RootParamType,
         SRV,               // t
         UAV,               // u

+ 20 - 6
src/AzslcUtils.h

@@ -958,15 +958,22 @@ namespace AZ::ShaderCompiler
         return UnqualifiedName{ctx->Name->getText()};
     }
 
-    template <typename ParentType>
-    bool Is3ParentRuleOfType(antlr4::ParserRuleContext* ctx)
+    //! Get a pointer to the first parent that happens to be of type `SearchType`
+    //! with a limit depth of `maxDepth` parents to search through
+    template <typename SearchType>
+    SearchType DeepParentAs(tree::ParseTree* ctx, int maxDepth)
     {
-        if (ctx == nullptr || ctx->parent == nullptr || ctx->parent->parent == nullptr)  // input canonicalization
+        if (auto* searchTypeNode = As<SearchType>(ctx))
         {
-            return false;
+            return searchTypeNode;
         }
-        auto threeUp = ctx->parent->parent->parent;
-        return dynamic_cast<ParentType>(threeUp);
+        return maxDepth <= 0 || !ctx ? nullptr : DeepParentAs<SearchType>(ctx->parent, maxDepth - 1);
+    }
+
+    template <typename ParentType>
+    bool Is3ParentRuleOfType(tree::ParseTree* ctx)
+    {
+        return DeepParentAs<ParentType>(ctx, 3);
     }
 
     // is def
@@ -1040,6 +1047,13 @@ namespace AZ::ShaderCompiler
         return found ? found->argumentList() : nullptr;
     }
 
+    //! access the argument count at a function call site (from the AST)
+    inline size_t NumArgs(azslParser::FunctionCallExpressionContext* callCtx)
+    {
+        azslParser::ArgumentsContext* argsNode = callCtx->argumentList()->arguments();
+        return argsNode ? argsNode->expression().size() : 0;
+    }
+
     //! try to find a specific context type that this context would be a child of.
     template <typename LookedUp>
     inline LookedUp* ExtractSpecificParent(antlr4::ParserRuleContext* ctx)

+ 178 - 51
src/GenericUtils.h

@@ -24,9 +24,9 @@ namespace AZ
     {
         using runtime_error::runtime_error;
     };
-    // Type-heterogeneity-preserving multi pointer object single visitor.
-    // Returns whatever the passed functor would.
-    // Throws if all passed objects are null.
+    //! Type-heterogeneity-preserving multi pointer object single visitor.
+    //! Returns whatever the passed functor would.
+    //! Throws if all passed objects are null.
     template <typename Lambda, typename T>
     invoke_result_t<Lambda, T*> VisitFirstNonNull(Lambda functor, T* object) noexcept(false)
     {
@@ -50,9 +50,9 @@ namespace AZ
         }
     }
 
-    // Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics.
-    // what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]`
-    // http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/
+    //! Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics.
+    //! what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]`
+    //! http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/
     inline constexpr string_view Slice(const string_view& in, int64_t st, int64_t end)
     {
         auto inSSize = (int64_t)in.size();
@@ -107,8 +107,8 @@ namespace AZ
     //https://developercommunity.visualstudio.com/content/problem/275141/c2131-expression-did-not-evaluate-to-a-constant-fo.html
     }
 
-    // ability to create size_t literals
-    // waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo
+    //! ability to create size_t literals
+    //! waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo
     inline constexpr std::size_t operator ""_sz(unsigned long long n)
     {
         return n;
@@ -145,7 +145,7 @@ namespace AZ
         return fileName.substr(0, lastIndex);
     }
 
-    // debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library)
+    //! debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library)
     template <typename To, typename From>
     To polymorphic_downcast(From instance)
     {
@@ -157,7 +157,7 @@ namespace AZ
         return static_cast<To>(instance);
     }
 
-    /// surround a string with a prefix and a suffix
+    //! surround a string with a prefix and a suffix
     inline string Decorate(string_view prefix, string_view body, string_view suffix)
     {
         std::stringstream ss;
@@ -167,13 +167,13 @@ namespace AZ
         return ss.str();
     }
 
-    /// 2 arguments version in case both sides are the same
+    //! 2 arguments version in case both sides are the same
     inline string Decorate(string_view prefixAndSuffix, string_view body)
     {
         return Decorate(prefixAndSuffix, body, prefixAndSuffix);
     }
 
-    /// reverse the effect of a symmetrical decoration
+    //! reverse the effect of a symmetrical decoration
     inline string_view Undecorate(string_view decoration, string_view body)
     {
         auto indexStart = StartsWith(body, decoration) ? decoration.length() : 0;
@@ -181,7 +181,7 @@ namespace AZ
         return Slice(body, indexStart, indexEnd);
     }
 
-    // Erase-Remove algorithm which removes all whitespaces from a string.
+    //! Erase-Remove algorithm which removes all whitespaces from a string.
     inline string RemoveWhitespaces(string haystack)
     {
         haystack.erase(std::remove_if(haystack.begin(), haystack.end(), [](unsigned char c) {return std::isspace(c); }), haystack.end());
@@ -193,14 +193,14 @@ namespace AZ
         return std::all_of(s.begin(), s.end(), [&](char c) { return std::isspace(c); });
     }
 
-    /// tells whether a position in a string is surrounded by round braces
-    /// e.g. true  for arguments {"a(b)", 2}
-    /// e.g. true  for arguments {"a()", 1}  by convention
-    /// e.g. false for arguments {"a()", 2}  by convention
-    /// e.g. false for arguments {"a(b)", 0}
-    /// e.g. false for arguments {"a(b)c", 4}
-    /// e.g. false for arguments {"a(b)c(d)", 4}
-    /// e.g. true  for arguments {"a((b)c(d))", 5}
+    //! tells whether a position in a string is surrounded by round braces
+    //! e.g. true  for arguments {"a(b)", 2}
+    //! e.g. true  for arguments {"a()", 1}  by convention
+    //! e.g. false for arguments {"a()", 2}  by convention
+    //! e.g. false for arguments {"a(b)", 0}
+    //! e.g. false for arguments {"a(b)c", 4}
+    //! e.g. false for arguments {"a(b)c(d)", 4}
+    //! e.g. true  for arguments {"a((b)c(d))", 5}
     inline bool WithinMatchedParentheses(string_view haystack, size_t charPosition)
     {
         const auto hayLen = haystack.length();
@@ -215,8 +215,8 @@ namespace AZ
         return nesting > 0;
     }
 
-    /// replace all occurrences of substring `sub` with substring `to` within haystack.
-    ///  e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa"
+    //! replace all occurrences of substring `sub` with substring `to` within haystack.
+    //!  e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa"
     inline string Replace(string haystack, string_view sub, string_view to)
     {
         decltype(sub.length()) pos = 0;
@@ -230,7 +230,7 @@ namespace AZ
         return haystack;
     }
 
-    // this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes.
+    //! this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes.
     constexpr inline string_view Trim(string_view haystack, string_view toTrim = " \t\n")
     {
         const auto strEnd = haystack.find_last_not_of(toTrim);
@@ -268,34 +268,56 @@ namespace AZ
         return std::find_if(begin, end, p) != end;
     }
 
-    /// argument in rangeV3-style version:
+    //! argument in rangeV3-style version:
     template< typename Container >
     string Join(const Container& c, string_view separator = "")
     {
         return Join(c.begin(), c.end(), separator);
     }
 
-    /// argument in rangeV3-style version:
+    //! argument in rangeV3-style version:
     template< typename Container, typename Predicate >
     bool Contains(const Container& c, Predicate p)
     {
         return Contains(c.begin(), c.end(), p);
     }
 
-    /// closest possible form of python's `in` keyword
+    //! closest possible form of python's `in` keyword
     template< typename Element, typename Container >
     bool IsIn(const Element& element, const Container& container)
     {
         return std::find(container.begin(), container.end(), element) != container.end();
     }
 
-    /// generate a new container with copy-and-mutated elements
+    //! generate a new container with copy-and-mutated elements
     template< typename Container, typename ContainerOut, typename Functor >
     void TransformCopy(const Container& in, ContainerOut& out, Functor mutator)
     {
         std::transform(in.begin(), in.end(), std::back_inserter(out), mutator);
     }
 
+    enum class CopyIfPolicy { ForAll, InterruptAtFirstFalse };
+
+    //! inserts elements into the output iterator if they pass a predicate
+    template< typename InputIterator, typename Predicate, typename OutputIterator >
+    void CopyIf(InputIterator begin, InputIterator end,
+                Predicate pred,
+                OutputIterator out,
+                CopyIfPolicy policy)
+    {
+        for (auto it = begin; it != end; ++it)
+        {
+            if (pred(*it))
+            {
+                *out = *it;
+            }
+            else if (policy == CopyIfPolicy::InterruptAtFirstFalse)
+            {
+                break;
+            }
+        }
+    }
+
     inline string Unescape(string_view escapedText)
     {
         std::stringstream out;
@@ -379,14 +401,14 @@ namespace AZ
     // Is-One-Of will check if a variable is equal to any of the values listed on the other parameters.
     // Example: IsOneOf(variableKind, Function, Enumeration) is short for: variableKind == Function || variableKind == Enumeration.
     // 2 arguments count: recursion terminal overload.
-    template <typename T>
-    bool IsOneOf(T value, T tocheck)
+    template <typename T, typename U>
+    bool IsOneOf(T value, U tocheck)
     {
         return value == tocheck;
     }
     // Any argument count version
-    template <typename T, typename... Args>
-    bool IsOneOf(T value, T khead, Args... tail)
+    template <typename T, typename U, typename... Args>
+    bool IsOneOf(T value, U khead, Args... tail)
     {
         return value == khead || IsOneOf(value, tail...);
     }
@@ -531,7 +553,7 @@ namespace AZ
 #endif
     }
 
-    /// Log(N) find immediate lower element query in map-keys
+    //! Log(N) query to find the first immediately lower or equal element in a map's keys
     template< typename T, typename U >
     auto Infimum(map<T, U> const& ctr, T query)
     {
@@ -539,35 +561,114 @@ namespace AZ
         return it == ctr.begin() ? ctr.cend() : --it;
     }
 
-    /// Log(N) disjointed segments belong query
-    /// You can represent segments as you wish, as long as:
-    ///   you provide the predicate to determine belonging.
-    ///   map-keys are segment start points.
-    ///   segments don't overlap.
-    /// returns: iterator to found interval key, or cend()
+    //! Log(N) disjointed segments belong query
+    //! You can represent segments as you wish, as long as:
+    //!   you provide the predicate to determine belonging.
+    //!   map-keys are segment start points.
+    //!   segments don't overlap.
+    //! returns: iterator to found interval key, or cend()
     template< typename T, typename U, typename IntervalCheckPredicate >
-    auto FindInterval(const map<T, U>& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate)
+    auto FindIntervalInDisjointSet(const map<T, U>& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate)
     {
         auto inf = Infimum(ctr, query);
-        return inf == ctr.end() ? ctr.cend() : (isInIntervalPredicate(query, inf->second) ? inf : ctr.cend());
+        bool isInInterval = inf != ctr.cend() && isInIntervalPredicate(query, inf->second);
+        return isInInterval ? inf : ctr.cend();
     }
 
-    /// Log(N) disjointed segments belong query
-    /// segments are represented by their start points in the key, and last point in values. segments can't overlap.
-    /// returns: iterator to found interval key, or cend()
+    //! Log(N) disjointed segments belong query
+    //! segments are represented by their start points in the key, and last point in values. segments can't overlap.
+    //! returns: iterator to found interval key, or cend()
     template< typename T, typename U>
-    auto FindInterval(const map<T, U>& ctr, const T& query)
+    auto FindIntervalInDisjointSet(const map<T, U>& ctr, const T& query)
     {
-        return FindInterval(ctr, query, [](T q, U last) {return q <= last; });
+        return FindIntervalInDisjointSet(ctr, query, [](T q, U last) {return q <= last; });
     }
 
+    template< typename T >
+    struct Interval
+    {
+        bool IsEmpty() const { return b < a; }
+        bool operator== (Interval const& rhs) const { return a == rhs.a && b == rhs.b; }
+        bool operator< (Interval const& rhs) const { return a < rhs.a || (a == rhs.a && b < rhs.b); }
+        T a = (T)0;
+        T b = (T)-1;
+    };
+
+    //! In case of potential overlaps (not disjointed), this structure can support "is in" queries
+    template< typename T >
+    struct IntervalCollection
+    {
+        using IntervalT = Interval<T>;
+
+        void Add(IntervalT i)
+        {
+            m_obfirsts.emplace_back(i);
+        }
+
+        //! doesn't respect RAII but for the sake of performance and convenience this is easier this way
+        void Seal()
+        {
+            m_oblasts = m_obfirsts;
+            std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2)
+                      {
+                          return i1.a < i2.a;
+                      });
+            std::sort(m_oblasts.begin(), m_oblasts.end(), [](auto i1, auto i2)
+                      {
+                          return i1.b < i2.b;
+                      });
+            m_sealed = true;
+        }
+
+        //! Retrieve the subset of intervals activated by a point (query)
+        set<IntervalT> GetIntervalsSurrounding(T query) const
+        {
+            assert(m_sealed);
+            // find the "set" of intervals starting before:
+            auto firstsSubEnd = std::lower_bound(m_obfirsts.begin(), m_obfirsts.end(),
+                                                 query,
+                                                 [=](auto interv, T q) { return interv.a <= q; });
+
+            // find the "set" of intervals ending after:
+            static vector<IntervalT> endAfter;
+            endAfter.clear();
+            CopyIf(m_oblasts.rbegin(), m_oblasts.rend(),  // reverse iteration
+                   [=](auto interv) { return interv.b >= query; },
+                   std::back_inserter(endAfter),
+                   CopyIfPolicy::InterruptAtFirstFalse);
+            // for set_intersection to work, the less<> predicate has to work for both ranges
+            std::sort(endAfter.begin(), endAfter.end());
+
+            set<IntervalT> result;
+            std::set_intersection(m_obfirsts.begin(), firstsSubEnd,
+                                  endAfter.begin(), endAfter.end(),
+                                  std::inserter(result, result.end()));
+            return result;
+        }
+
+        //! Get the interval surrounding query that has the closest start point to query.
+        //! In case of an interval collection representing a tree, that is,
+        //! each overlapping interval is fully contained in the bigger one,
+        //! the closest start is guaranteed to be the most "leaf" interval.
+        //! This is useful for scopes.
+        IntervalT GetClosestIntervalSurrounding(T query) const
+        {
+            auto bag = GetIntervalsSurrounding(query);
+            return bag.empty() ? IntervalT{-1, -2} : *bag.rbegin();
+        }
+
+        vector<IntervalT> m_obfirsts;  // ordered by "firsts"
+        vector<IntervalT> m_oblasts;   // ordered by "lasts"
+        bool m_sealed = false;
+    };
+
     template< typename Deduced >
     decltype(auto) CastToRValueReference(Deduced&& value)
     {
         return static_cast<std::remove_reference_t<Deduced>&&>(value);
     }
 
-    /// add a missing operator for convenience and shortness of code
+    //! add a missing operator for convenience and shortness of code
     inline bool operator == (string_view lhs, char rhs)
     {
         return lhs.length() == 1 && lhs[0] == rhs;
@@ -624,6 +725,15 @@ namespace AZ
         RemoveDuplicatesKeepOrder(lhs);
     }
 
+    //! Conditional swap algorithm
+    template<typename T>
+    void SwapIf(T&& a, T&& b, bool condition)
+    {
+        if (condition)
+        {
+            std::swap(std::forward<T>(a), std::forward<T>(b));
+        }
+    }
 }
 
 #ifndef NDEBUG
@@ -704,10 +814,10 @@ namespace AZ::Tests
             assert(yellow == intervals.cend());
             auto larger_than_all = Infimum(intervals, 15);
             assert(larger_than_all->first == 8);
-            assert(FindInterval(intervals, 4) != intervals.cend());
-            assert(FindInterval(intervals, 6) == intervals.cend());
-            assert(FindInterval(intervals, 8) != intervals.cend());
-            assert(FindInterval(intervals, 1) == intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 4) != intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 6) == intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 8) != intervals.cend());
+            assert(FindIntervalInDisjointSet(intervals, 1) == intervals.cend());
 
             auto high = Infimum(intervals, 20);
             assert(high->first == 8);
@@ -725,6 +835,23 @@ namespace AZ::Tests
 
         assert(IsIn("hibou", std::initializer_list<const char*>{ "chouette", "hibou", "jay" }));
         assert(!IsIn("hibou", std::initializer_list<const char*>{ "chouette", "jay" }));
+
+        Interval<int> intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}};
+        IntervalCollection<int> ic;
+        std::for_each(std::begin(intvs), std::end(intvs), [&](auto i) {ic.Add(i); });
+        ic.Seal();
+        assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty());
+        assert((ic.GetClosestIntervalSurrounding(0) == Interval<int>{0,10}));
+        assert((ic.GetClosestIntervalSurrounding(1) == Interval<int>{1,5}));
+        assert((ic.GetClosestIntervalSurrounding(3) == Interval<int>{3,3}));
+        assert((ic.GetClosestIntervalSurrounding(4) == Interval<int>{1,5}));
+        assert((ic.GetClosestIntervalSurrounding(6) == Interval<int>{0,10}));
+        assert((ic.GetClosestIntervalSurrounding(5) == Interval<int>{1,5}));
+        assert((ic.GetClosestIntervalSurrounding(7) == Interval<int>{7,9}));
+        assert((ic.GetClosestIntervalSurrounding(9) == Interval<int>{7,9}));
+        assert(ic.GetClosestIntervalSurrounding(11).IsEmpty());
+        assert((ic.GetClosestIntervalSurrounding(13) == Interval<int>{12,15}));
+        assert(ic.GetClosestIntervalSurrounding(16).IsEmpty());
     }
 }
 #endif

+ 4 - 0
src/MetaUtils.h

@@ -366,6 +366,10 @@ namespace AZ
 
     template <class Default, template<class...> class Op, class... Args>
     using DetectedOr_t = typename DetectedOr<Default, Op, Args...>::type;
+
+    //! define a Pair typealias that has two same T, without the need to repeat yourself
+    template <typename T>
+    using Pair = std::pair<T, T>;
 }
 
 #ifndef NDEBUG

+ 1 - 1
src/PadToAttributeMutator.cpp

@@ -354,7 +354,7 @@ namespace AZ::ShaderCompiler
             else if (varInfo.GetTypeClass() == TypeClass::Enum)
             {
                 auto* asClassInfo = m_ir.GetSymbolSubAs<ClassInfo>(varInfo.GetTypeId().GetName());
-                size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.GetBaseSize();
+                size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.m_baseSize;
             }
 
             offset = Packing::PackNextChunk(layoutPacking, size, startAt);

+ 27 - 0
tests/Advanced/mae-methodcall.azsl

@@ -0,0 +1,27 @@
+ShaderResourceGroupSemantic slot0
+{
+    FrequencyId = 1;
+    ShaderVariantFallback = 128;
+};
+ShaderResourceGroup srg0 : slot0{}
+
+class C
+{
+    void f(double) { f(0) + f(0) * f(0) / f(0); f(0) - f(0) % f(0); }  // cost 7*7+2
+    void f(int) { ;;;;;;; }  // cost 7
+};
+
+option bool o;
+
+float4 main()
+{
+    if (o)  // unnamed block $bk0
+    {
+        C c;
+        {   // unnamed block $bk1 to verify lookup capability to find `/main/$bk0/c` from `/main/$bk0/$bk1/`
+            // understand that `c`'s type is `/C`, and use /C scope to lookup the f() method.
+            // also deep expression on LHS of MAE to give no break to the typeof system
+            (c).f(2 * 5.0l);   // double promotion in binary expression that resolves to double overload method call
+        }
+    }
+}

+ 46 - 0
tests/Advanced/mae-methodcall.py

@@ -0,0 +1,46 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) Contributors to the Open 3D Engine Project.
+For complete copyright and license terms please see the LICENSE at the root of this distribution.
+
+SPDX-License-Identifier: Apache-2.0 OR MIT
+"""
+import sys
+import os
+sys.path.append("..")
+sys.path.append("../..")
+from clr import *
+import testfuncs
+
+
+def verifyOptionCosts(thefile, compilerPath, silent):
+    j, ok = testfuncs.buildAndGetJson(thefile, compilerPath, silent, ["--options"])
+    if ok:
+        predicates = []
+        # check all references of func()
+        predicates.append(lambda: j["ShaderOptions"][0]["name"] == "o")
+        predicates.append(lambda: j["ShaderOptions"][0]["costImpact"] == 54)
+
+        if not silent: print (fg.CYAN+ style.BRIGHT+ "option expected cost check..."+ style.RESET_ALL)
+        ok = testfuncs.verifyAllPredicates(predicates, j)
+    return ok
+
+result = 0  # to define for sub-tests
+resultFailed = 0
+
+def doTests(compiler, silent, azdxcpath):
+    global result
+    global resultFailed
+
+    # Working directory should have been set to this script's directory by the calling parent
+    # You can get it once doTests() is called, but not during initialization of the module,
+    #  because at that time it will still be set to the working directory of the calling script
+    workDir = os.getcwd()
+
+    if verifyOptionCosts(os.path.join(workDir, "mae-methodcall.azsl"), compiler, silent): result += 1
+    else: resultFailed += 1
+
+
+if __name__ == "__main__":
+    print ("please call from testapp.py")

+ 3 - 2
tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl

@@ -2,12 +2,13 @@ struct A {};
 struct B {};
 
 A make(int);
-B make(uint);
+B make(float);
 
 void main()
 {
     float x = 0.5;
-    A a = make((int)floor(x) + 1);  // #EC 41
+    A a = make(floor(x));  // #EC 41
+    //    make((int)floor(x));  // help azslc knowing about unregistered functions by casting to force the type.
 }
 /*Semantic Error 41: line 10::14 '(10): unable to match arguments (<fail>) to a registered overload. candidates are:
 /make(?int)

+ 110 - 6
tests/Semantic/typeof-keyword.azsl

@@ -19,7 +19,7 @@ top gettop();
 
 class A
 {
-	int a;
+    int a;
 };
 
 class B : A
@@ -83,16 +83,14 @@ void h()
     //   NumericConstructorExpression  float2(0,0)
     //   LiteralExpression             42
     //   CommaExpression               X, Y
-    // not supported:
     //   PostfixUnaryExpression        i++
     //   PrefixUnaryExpression         ++i
     //   BinaryExpression              i + j
-    // e.g. typeof(1 + 3) = <fail>
 
-     // mathematics
+    // mathematics
     __azslc_print_message("@check predicate ");
     __azslc_print_symbol(typeof(1 + 3), __azslc_prtsym_least_qualified);
-    __azslc_print_message(" == '<fail>'\n");
+    __azslc_print_message(" == 'int'\n");
 
     // literals
     __azslc_print_message("@check predicate ");
@@ -186,7 +184,7 @@ void h()
     __azslc_print_symbol(typeof(top::inner), __azslc_prtsym_fully_qualified);
     __azslc_print_message(" == '/top/inner'\n");
 
-	// class inheritance: parent member access using scope resolution operator
+    // class inheritance: parent member access using scope resolution operator
     __azslc_print_message("@check predicate ");
     __azslc_print_symbol(typeof(B::a), __azslc_prtsym_least_qualified);
     __azslc_print_message(" == 'int'\n");
@@ -445,4 +443,110 @@ void h()
     __azslc_print_message("@check predicate ");
     __azslc_print_symbol(typeof(INTVAR, INTVAR, DOUBLEVAR), __azslc_prtsym_least_qualified);
     __azslc_print_message(" == 'double'\n");
+
+    // binary arithmetic expression
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1 + 1), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'int'\n");
+
+    // with float promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1 + 1.f), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float'\n");
+
+    // with half promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1 + 1.h), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'half'\n");
+
+    // half and float to float promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1.f + 1.h), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float'\n");
+
+    // int16_t to double promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1.l + int16_t(1)), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double'\n");
+
+    // bool binary
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(1.l && int16_t(1)), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'bool'\n");
+
+    // lookedup
+    double d;
+    int64_t i64;
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(d || i64), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'bool'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(d - i64), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double'\n");
+
+    // vector scalar
+    float4 f4;
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(2.f * f4), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float4'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(f4 * 2.f), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float4'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(d * f4), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double4'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(f4 * d), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double4'\n");
+
+    // matrix scalar
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float() * float3x2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float3x2'\n");
+
+    // matrix scalar with base type promotion
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(half3x2() * double()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double3x2'\n");
+
+    // truncations
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float3() * float2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float2'\n");
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float4x4() * float2x2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'float2x2'\n");
+
+    // truncate & upcast
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(float4x4() * double2x2()), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double2x2'\n");
+
+    // through alias
+    typealias d34m = double3x4;
+    typealias real = half;
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof((d34m)0 * (real)0), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double3x4'\n");
+    // note: the parser takes d32m() as a function call, this is the "most verxing parse" problem
+    //       float() is understood by the parser as a NumericConstructorExpression because it
+    //       has a list of the tokens representing all fundamental types.
+    //       but, with user defined identifiers, it can't branch into the "intended" context,
+    //       because ALL(*) Antlr4 parsers recognizes context free grammar only.
+    //       We could adopt a universal construction syntax Obj{} a la C++ for AZSL but it's not
+    //       compliant with the philosophy of not deviating from HLSL, which on its side doesn't really
+    //       accept constructor-type constructs. DXC just does it better here because clang parser is Turing complete.
+
+    d34m d34var;
+    real rVar;
+
+    __azslc_print_message("@check predicate ");
+    __azslc_print_symbol(typeof(rVar - d34var ), __azslc_prtsym_least_qualified);
+    __azslc_print_message(" == 'double3x4'\n");
 }