Przeglądaj źródła

First working proof of concept of option rank cost static analyzer.

Signed-off-by: Vivien Oddou <[email protected]>
Vivien Oddou 2 lat temu
rodzic
commit
59ca8f14d3

+ 55 - 9
src/AzslcReflection.cpp

@@ -1072,17 +1072,63 @@ namespace AZ::ShaderCompiler
                     if (auto* idExpr = As<azslParser::IdentifierExpressionContext*>(callNode->Expr))
                     {
                         UnqualifiedName funcName = ExtractNameFromIdExpression(idExpr->idExpression());
-                        m_ir->m_sema.ResolveOverload(
-                        IdAndKind* symbolMeantUnderCallNode = m_ir->m_symbols.LookupSymbol(encloser.GetName(), funcName);
-                        auto* funcInfo = symbolMeantUnderCallNode->second.GetSubAs<FunctionInfo>();
-                        if (funcInfo->m_costScore == -1)
+                        IdAndKind* overload = m_ir->m_symbols.LookupSymbol(encloser.GetName(), funcName);
+                        if (!overload) // in case of function not found, we assume it's an intrinsic.
                         {
-                            funcInfo->m_costScore = 0;
-                            using AstFDef = azslParser::HlslFunctionDefinitionContext;
-                            AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
-                                          funcInfo->m_costScore);  // recurse and cache if not already done
+                            if (IsOneOf(funcName, "CallShader", "TraceRay"))
+                            { // non measurable but assumed high
+                                scoreAccumulator += 50;
+                            }
+                            else if (IsOneOf(funcName, "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append"))
+                            { // hardware locked memory ops, high weight
+                                scoreAccumulator += 10;
+                            }
+                            else if (IsOneOf(funcName, "Sample", "Load"))
+                            { // memory access is weighted in between
+                                scoreAccumulator += 5;
+                            }
+                            else
+                            { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1.
+                                scoreAccumulator += 1;
+                            }
+                        }
+                        else
+                        {
+                            azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode);
+                            IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args);
+                            IdentifierUID concrete;
+                            if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet)
+                            { // in case of strict selection failure, run a fuzzy select
+                                size_t numArgs = NumArgs(callNode);
+                                overload->second.GetSubAs<OverloadSetInfo>()->AnyOf(
+                                    [&](IdentifierUID const& uid)
+                                    {
+                                        auto* concreteFcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(uid.GetName());
+                                        size_t numParams = concreteFcInfo->GetParameters(true).size();
+                                        if (numParams == numArgs)
+                                        {
+                                            concrete = uid;
+                                            return true;
+                                        }
+                                        return false;
+                                    }
+                                );
+                                // if still not enough to get a fix, it might be an ill-formed input. prefer to forfeit
+                            }
+                            else
+                            {
+                                concrete = symbolMeantUnderCallNode->first;
+                            }
+                            auto* funcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(concrete.GetName());
+                            if (funcInfo->m_costScore == -1)
+                            {
+                                funcInfo->m_costScore = 0;
+                                using AstFDef = azslParser::HlslFunctionDefinitionContext;
+                                AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
+                                              funcInfo->m_costScore);  // recurse and cache if not already done
+                            }
+                            scoreAccumulator += funcInfo->m_costScore;
                         }
-                        scoreAccumulator += funcInfo->m_costScore;
                     }
                     // other cases forfeited for now, but that would at least be braces (f)() or MAE x.m()
                 }

+ 28 - 14
src/AzslcSemanticOrchestrator.cpp

@@ -1166,7 +1166,9 @@ namespace AZ::ShaderCompiler
                                      As<azslParser::AssignmentExpressionContext*>(ctx),
                                      As<azslParser::NumericConstructorExpressionContext*>(ctx),
                                      As<azslParser::LiteralExpressionContext*>(ctx),
-                                     As<azslParser::LiteralContext*>(ctx));
+                                     As<azslParser::LiteralContext*>(ctx),
+                                     As<azslParser::PrefixUnaryExpressionContext*>(ctx),
+                                     As<azslParser::PostfixUnaryExpressionContext*>(ctx));
         }
         catch (AllNull&)
         {
@@ -1175,6 +1177,24 @@ namespace AZ::ShaderCompiler
         }
     }
 
+    QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const
+    {
+        // among all possibilities Plus|Minus|Not|Tilde|PlusPlus|MinusMinus
+        // only "Not" returns a bool, the rest is transparent and returns the same type as rhs
+        return ctx->prefixUnaryOperator()->start->getType() == azslLexer::Not ? MangleScalarType("bool")
+            : TypeofExpr(ctx->Expr);
+    }
+
+    QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const
+    {
+        return TypeofExpr(ctx->Expr); // in case of x++ or x-- the type is the type of x.
+    }
+
+    QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::BinaryExpressionContext* ctx) const
+    {
+        return TypeofExpr(ctx->Expr);
+    }
+
     QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const
     {
         return VisitFirstNonNull([this](auto* ctx) { return TypeofExpr(ctx); },
@@ -1303,29 +1323,23 @@ namespace AZ::ShaderCompiler
 
     QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::LiteralContext* ctx) const
     {
-        // verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array.
-        auto checkExistType = [](string_view scalarName){return std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), scalarName) != AZ::ShaderCompiler::Predefined::Scalar.end();};
         // verifies that last or 1-before-last characters are a particular literal suffix. like in "56ul"
         auto hasSuffix = [](auto node, char s){return tolower(node->getText().back()) == s || tolower(Slice(node->getText(), -3, -2) == s);};
-        auto checkAndReturn = [&](string_view typeName)
-                              {
-                                  assert(checkExistType(typeName));
-                                  return QualifiedName{"?" + string{typeName}};
-                              };
+
         if (ctx->True() || ctx->False())
         {
-            return checkAndReturn("bool");
+            return MangleScalarType("bool");
         }
         else if (auto* literal = ctx->IntegerLiteral())
         {
-            return hasSuffix(literal, 'u') ? checkAndReturn("uint")
-                 : checkAndReturn("int");
+            return hasSuffix(literal, 'u') ? MangleScalarType("uint")
+                 : MangleScalarType("int");
         }
         else if (auto* literal = ctx->FloatLiteral())
         {
-            return hasSuffix(literal, 'h') ? checkAndReturn("half")
-                 : hasSuffix(literal, 'l') ? checkAndReturn("double")
-                 : checkAndReturn("float");
+            return hasSuffix(literal, 'h') ? MangleScalarType("half")
+                 : hasSuffix(literal, 'l') ? MangleScalarType("double")
+                 : MangleScalarType("float");
         }
         return {"<fail>"};
     }

+ 4 - 1
src/AzslcSemanticOrchestrator.h

@@ -222,6 +222,9 @@ namespace AZ::ShaderCompiler
         auto TypeofExpr(azslParser::LiteralExpressionContext* ctx) const -> QualifiedName;
         auto TypeofExpr(azslParser::LiteralContext* ctx) const -> QualifiedName;
         auto TypeofExpr(azslParser::CommaExpressionContext* ctx) const -> QualifiedName;
+        auto TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const -> QualifiedName;
+        auto TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const -> QualifiedName;
+        auto TypeofExpr(azslParser::BinaryExpressionContext* ctx) const -> QualifiedName;
 
         //! Parse the AST from a variable declaration and attempt to extract array dimensions integer constants [dim1][dim2]...
         //! Return: <true> on success, <false> otherwise
@@ -330,7 +333,7 @@ namespace AZ::ShaderCompiler
         {
             auto typeId      = LookupType(typeNameOrCtx, policy);
             auto tClass      = GetTypeClass(typeId);
-            auto arithmetic  = IsNonGenericArithmetic(tClass) ? CreateArithmeticTypeInfo(typeId.GetName()) : ArithmeticTypeInfo{}; // TODO: canonicalize generics
+            auto arithmetic  = IsNonGenericArithmetic(tClass) ? CreateArithmeticTraits(typeId.GetName()) : ArithmeticTraits{}; // TODO: canonicalize generics
             return TypeRefInfo{typeId, arithmetic, tClass};
         }
 

+ 1 - 1
src/AzslcSymbolAggregator.cpp

@@ -20,7 +20,7 @@ namespace AZ::ShaderCompiler
             auto& [uid, kindInfo] = st.AddIdentifier(azirName, Kind::Type);  // the kind is Type because all predefined are stored as such.
             auto& typeInfo        = kindInfo.GetSubAfterInitAs<Kind::Type>();
             auto typeClass        = TypeClass::FromStr(bag.m_name);
-            auto arithmetic       = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTypeInfo(azirName) : ArithmeticTypeInfo{}; // TODO: canonicalize generics
+            auto arithmetic       = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTraits(azirName) : ArithmeticTraits{}; // TODO: canonicalize generics
             typeInfo = TypeRefInfo{uid, arithmetic, typeClass};
         }
     }

+ 34 - 21
src/AzslcTypes.h

@@ -260,24 +260,36 @@ namespace AZ::ShaderCompiler
         return result;
     }
 
+    /// Verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array.
+    inline bool CheckExistScalarType(string_view scalarName)
+    {
+        return std::find(Predefined::Scalar.begin(),
+                         Predefined::Scalar.end(),
+                         scalarName) != Predefined::Scalar.end();
+    };
+
+    /// Assert validity of the type string, and form a "?<type>" tainted string to host a scalar type
+    inline QualifiedName MangleScalarType(string_view typeName)
+    {
+        assert(CheckExistScalarType(typeName));
+        return QualifiedName{"?" + string{typeName}};
+    };
+
     /// Rows and Cols (this is specific to shader languages to identify vector and matrix types)
-    struct ArithmeticTypeInfo
+    struct ArithmeticTraits
     {
-        void ResolveSize()
+        void ResolveBaseSizeAndRank()
         {
-            m_size = Packing::PackedSizeof(m_underlyingScalar);
-        }
+            m_baseSize = Packing::PackedSizeof(m_underlyingScalar);
 
-        /// Get the size of a single base element
-        const uint32_t GetBaseSize() const
-        {
-            return m_size;
+            
+            
         }
 
-                /// Get the size of a the element with regard to dimensions as well
+        /// Get the size of the whole type considering dimensions
         const uint32_t GetTotalSize() const
         {
-            return m_size * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1);
+            return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1);
         }
 
         /// True if the type is a vector type. If it's a vector type it cannot be a matrix as well.
@@ -310,10 +322,11 @@ namespace AZ::ShaderCompiler
                 AZ::ShaderCompiler::Predefined::Scalar[m_underlyingScalar] : "<NA>";
         }
 
-        uint32_t m_size = 0;                     // In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct
-        uint32_t m_rows = 0;                     // 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix
-        uint32_t m_cols = 0;                     // 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix
-        int      m_underlyingScalar = -1;        // index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf.
+        uint32_t m_baseSize = 0;                 //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct
+        uint32_t m_rows = 0;                     //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix
+        uint32_t m_cols = 0;                     //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix
+        int      m_conversionRank = 0;           //< Used in conversions and promotions
+        int      m_underlyingScalar = -1;        //< Index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf.
     };
 
     //! TypeRefInfo holds resolved immutable information of a core type (the `matrix2x2` in `column_major matrix2x2 a[3];`)
@@ -323,7 +336,7 @@ namespace AZ::ShaderCompiler
     struct TypeRefInfo
     {
         TypeRefInfo() = default;
-        TypeRefInfo(IdentifierUID typeId, const ArithmeticTypeInfo& fundamentalInfo, TypeClass typeClass)
+        TypeRefInfo(IdentifierUID typeId, const ArithmeticTraits& fundamentalInfo, TypeClass typeClass)
             : m_arithmeticInfo{fundamentalInfo},
               m_typeClass{typeClass},
               m_typeId{typeId}
@@ -359,19 +372,19 @@ namespace AZ::ShaderCompiler
             return !operator==(lhs,rhs);
         }
 
-        IdentifierUID       m_typeId;
-        TypeClass           m_typeClass;
-        ArithmeticTypeInfo  m_arithmeticInfo;
+        IdentifierUID     m_typeId;
+        TypeClass         m_typeClass;
+        ArithmeticTraits  m_arithmeticInfo;
     };
 
     //! Run a syntactic analysis of an arithmetic type name and extract info on its composition
-    inline ArithmeticTypeInfo CreateArithmeticTypeInfo(QualifiedName a_typeName)
+    inline ArithmeticTraits CreateArithmeticTraits(QualifiedName a_typeName)
     {
         assert(IsArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) ));  // no need to call this function if you don't have a fundamental (non void)
         assert(!IsGenericArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) ));
         // ↑ fatal aspect. The input needs to be canonicalized earlier to minimize this function's complexity.
 
-        ArithmeticTypeInfo toReturn;
+        ArithmeticTraits toReturn;
 
         string typeName = UnMangle(a_typeName);
         size_t baseTypeLen = typeName.length();
@@ -422,7 +435,7 @@ namespace AZ::ShaderCompiler
         auto it = ::std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), baseType);
         assert(it != AZ::ShaderCompiler::Predefined::Scalar.end()); // baseType must exist in the Scalar bag by program invariant.
         toReturn.m_underlyingScalar = static_cast<int>( std::distance(AZ::ShaderCompiler::Predefined::Scalar.begin(), it) );
-        toReturn.ResolveSize();
+        toReturn.ResolveBaseSizeAndRank();
         return toReturn;
     }
 

+ 7 - 0
src/AzslcUtils.h

@@ -1047,6 +1047,13 @@ namespace AZ::ShaderCompiler
         return found ? found->argumentList() : nullptr;
     }
 
+    //! access the argument count at a function call site (from the AST)
+    inline size_t NumArgs(azslParser::FunctionCallExpressionContext* callCtx)
+    {
+        azslParser::ArgumentsContext* argsNode = callCtx->argumentList()->arguments();
+        return argsNode ? argsNode->expression().size() : 0;
+    }
+
     //! try to find a specific context type that this context would be a child of.
     template <typename LookedUp>
     inline LookedUp* ExtractSpecificParent(antlr4::ParserRuleContext* ctx)

+ 4 - 4
src/GenericUtils.h

@@ -379,14 +379,14 @@ namespace AZ
     // Is-One-Of will check if a variable is equal to any of the values listed on the other parameters.
     // Example: IsOneOf(variableKind, Function, Enumeration) is short for: variableKind == Function || variableKind == Enumeration.
     // 2 arguments count: recursion terminal overload.
-    template <typename T>
-    bool IsOneOf(T value, T tocheck)
+    template <typename T, typename U>
+    bool IsOneOf(T value, U tocheck)
     {
         return value == tocheck;
     }
     // Any argument count version
-    template <typename T, typename... Args>
-    bool IsOneOf(T value, T khead, Args... tail)
+    template <typename T, typename U, typename... Args>
+    bool IsOneOf(T value, U khead, Args... tail)
     {
         return value == khead || IsOneOf(value, tail...);
     }