瀏覽代碼

WIP: satisfactory working order for vector and matrix arithmetic type deduction in binary operations

Signed-off-by: Vivien Oddou <[email protected]>
Vivien Oddou 2 年之前
父節點
當前提交
31b44baad2
共有 2 個文件被更改,包括 78 次插入11 次删除
  1. 28 3
      src/AzslcSemanticOrchestrator.cpp
  2. 50 8
      src/AzslcTypes.h

+ 28 - 3
src/AzslcSemanticOrchestrator.cpp

@@ -1205,12 +1205,37 @@ namespace AZ::ShaderCompiler
         TypeRefInfo typeInfoRhs = CreateTypeRefInfo(UnqualifiedNameView{rhs});
         if (typeInfoLhs.m_arithmeticInfo.IsEmpty() || typeInfoRhs.m_arithmeticInfo.IsEmpty())
         {   // Case that shouldn't work in AZSL yet (but may work in HLSL2021)
-            // -> UDT operator. need operator overloading. We assume type is type of left expression.
+            // -> UDT operator (would need support of operator overloading).
+            // We assume type is type of left expression.
+            // (what we really need to do is go get the return type of the overloaded operator)
             return lhs;
         }
+        // After this is, both sides are arithmetic types (scalar, vector, matrix).
+        // matrix op vector is a forbidden case,
+        //  e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar)
+        // It will do component to component op and return more or less the same type.
+        // In case of dimension differences, it will truncate to smaller type
+        //  e.g float2 + float3 results in float2 with .z lost in implicit cast
+        //      same for float2x3 * float2x2 (results in float2x2)
+        // We assume that for * + - / % ^ | & << >>
+        bool lhsIsVecMat = typeInfoLhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix);
+        bool rhsIsVecMat = typeInfoRhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix);
+        if (lhsIsVecMat !=/*xor*/ rhsIsVecMat)
+        {
+            auto scalarOperand = lhsIsVecMat ? typeInfoRhs : typeInfoLhs;
+            auto vecmatOperand = lhsIsVecMat ? typeInfoLhs : typeInfoRhs;
+            // typeof(vecmat op scalar)->promoted(vecmat)
+            return vecmatOperand.m_arithmeticInfo.PromoteTruncateWith(scalarOperand.m_arithmeticInfo).GenTypeId();
+        }
+        //else if (lhsIsVecMat && rhsIsVecMat)
+        {
+            // typeof(vecmat op vecmat)->promoted(truncated(vecmat))
+            return typeInfoLhs.m_arithmeticInfo.PromoteTruncateWith(typeInfoRhs.m_arithmeticInfo).GenTypeId();
+        }
+        // case left: both sides are scalar.
         // final logic in case of arithmetic type class: integer/float promotion.
-        return typeInfoLhs.m_arithmeticInfo.m_conversionRank > typeInfoRhs.m_arithmeticInfo.m_conversionRank ?
-            lhs : rhs;
+        //return typeInfoLhs.m_arithmeticInfo.m_conversionRank > typeInfoRhs.m_arithmeticInfo.m_conversionRank ?
+        //    lhs : rhs;
     }
 
     QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const

+ 50 - 8
src/AzslcTypes.h

@@ -275,7 +275,7 @@ namespace AZ::ShaderCompiler
         return QualifiedName{"?" + string{typeName}};
     };
 
-    /// Rows and Cols (this is specific to shader languages to identify vector and matrix types)
+    //! Holds arithmetic-type-class information in small pieces (row, cols, base, size, rank...)
     struct ArithmeticTraits
     {
         void ResolveBaseSizeAndRank()
@@ -294,7 +294,6 @@ namespace AZ::ShaderCompiler
             //   - "standard" is > "extended" of same sizeof
             //   - The rank of bool is the smallest
             // That said, we will take inspiration from ASTContext::getIntegerRank of clang
-            // (which does not respect C++ visibly, since it takes bool's size into account, or has many equivalent ranks, in violation of rule 1)
             static const unordered_map<int, int> subranks =
             {
                 {getIndex("bool"), 1},
@@ -319,17 +318,17 @@ namespace AZ::ShaderCompiler
                 return isBool ? 1 : m_baseSize;
             };
             // The shift method is taken from clang, I suppose it's a multi-parameter order cramed into bits.
-            // so because 10 is the largest subrank, shift by 4 should separate sizeof space and subank space.
+            // so because 10 is the largest subrank, shift by 4 should separate sizeof space and subrank space.
             m_conversionRank = (getRankSizeof(m_underlyingScalar) << 4) + subranks.at(m_underlyingScalar);
         }
 
-        /// Get the size of the whole type considering dimensions
+        //! Get the size of the whole type considering dimensions
         const uint32_t GetTotalSize() const
         {
             return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1);
         }
 
-        /// True if the type is a vector type. If it's a vector type it cannot be a matrix as well.
+        //! True if the type is a vector type. If it's a vector type it cannot be a matrix as well.
         const bool IsVector() const
         {
             // This treats special cases like 2x1, 3x1 and 4x1 as vectors
@@ -337,7 +336,7 @@ namespace AZ::ShaderCompiler
             return (m_cols == 1 && m_rows > 1) || (m_cols > 1 && m_rows == 0);
         }
 
-        /// True if the type is a matrix type. If it's a matrix type it cannot be a vector as well.
+        //! True if the type is a matrix type. If it's a matrix type it cannot be a vector as well.
         const bool IsMatrix() const
         {
             // This fails special cases like 2x1, 3x1 and 4x1,
@@ -346,19 +345,62 @@ namespace AZ::ShaderCompiler
             return m_rows > 0 && m_cols > 1;
         }
 
-        /// If initialized as a fundamental -> not empty.
+        //! If initialized as a fundamental -> not empty.
         const bool IsEmpty() const
         {
             return m_underlyingScalar == -1;
         }
 
-        // for pretty print
+        //! For pretty print
         string UnderlyingScalarToStr() const
         {
             return m_underlyingScalar >= 0 && m_underlyingScalar < AZ::ShaderCompiler::Predefined::Scalar.size() ?
                 AZ::ShaderCompiler::Predefined::Scalar[m_underlyingScalar] : "<NA>";
         }
 
+        //! Create a canonicalized mangled name that should represent the identity of this arithmetic type.
+        QualifiedName GenTypeId() const
+        {
+            if (IsMatrix())
+            {
+                return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + ToString(m_rows) + "x" + ToString(m_cols)};
+            }
+            else if (IsVector())
+            {
+                return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + (m_rows > 0 ? ToString(m_rows) : ToString(m_cols))};
+            }
+            else
+            {
+                return QualifiedName{MangleScalarType(UnderlyingScalarToStr())};
+            }
+        }
+
+        //! Create a new arithmetic traits promoted by necessity (through a binary operation usually)
+        //! of compatibility with a second operand of arithmetic typeclass.
+        //! For example: type(half{} + int{})->half
+        //!              type(float3x3{} * double{})->double3x3
+        //! And with columns & rows truncated to the smallest operand,
+        //!   as part of the implicit necessary cast for operation compatibility.
+        ArithmeticTraits PromoteTruncateWith(const ArithmeticTraits& secondOperand) const
+        {
+            ArithmeticTraits copy{*this};
+            // The higher ranking underlying wins independently of global object size
+            if (secondOperand.m_conversionRank > m_conversionRank)
+            {
+                copy.m_underlyingScalar = secondOperand.m_underlyingScalar;
+            }
+            if (secondOperand.m_rows > 0 && m_rows > 0)
+            {
+                copy.m_rows = std::min(m_rows, secondOperand.m_rows);
+            }
+            if (secondOperand.m_cols > 0 && m_cols > 0)
+            {
+                copy.m_cols = std::min(m_cols, secondOperand.m_cols);
+            }
+            copy.ResolveBaseSizeAndRank();
+            return copy;
+        }
+
         uint32_t m_baseSize = 0;                 //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct
         uint32_t m_rows = 0;                     //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix
         uint32_t m_cols = 0;                     //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix