Browse Source

Cleaner implementation of binary operation type detection.
Principle: diminish code paths at maximum by reducing the conditions checked to an absolute minimalist level.

Signed-off-by: Vivien Oddou <[email protected]>

Vivien Oddou 2 years ago
parent
commit
878c89c9af
4 changed files with 64 additions and 60 deletions
  1. 11 29
      src/AzslcSemanticOrchestrator.cpp
  2. 40 31
      src/AzslcTypes.h
  3. 9 0
      src/GenericUtils.h
  4. 4 0
      src/MetaUtils.h

+ 11 - 29
src/AzslcSemanticOrchestrator.cpp

@@ -1205,37 +1205,19 @@ namespace AZ::ShaderCompiler
         TypeRefInfo typeInfoRhs = CreateTypeRefInfo(UnqualifiedNameView{rhs});
         if (typeInfoLhs.m_arithmeticInfo.IsEmpty() || typeInfoRhs.m_arithmeticInfo.IsEmpty())
         {   // Case that shouldn't work in AZSL yet (but may work in HLSL2021)
-            // -> UDT operator (would need support of operator overloading).
-            // We assume type is type of left expression.
-            // (what we really need to do is go get the return type of the overloaded operator)
+            //  -> UDT operator (would need support of operator overloading).
+            // We arbitrarily assume a result "type of left expression".
+            // (what we would really need to do is go get the return type of the overloaded operator)
             return lhs;
-        }
-        // After this is, both sides are arithmetic types (scalar, vector, matrix).
-        // matrix op vector is a forbidden case,
-        //  e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar)
-        // It will do component to component op and return more or less the same type.
-        // In case of dimension differences, it will truncate to smaller type
-        //  e.g float2 + float3 results in float2 with .z lost in implicit cast
+        } // After this `if`, both sides are arithmetic types (scalar, vector, matrix).
+        // "matrix op vector" (or commutated) is a forbidden case,
+        //   e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar)
+        // It will do piecewise `op` and return more or less the same type as the operands.
+        // In case of dimension differences, it will truncate to the smaller type
+        //   e.g float2 + float3 results in float2 with .z lost in implicit cast
         //      same for float2x3 * float2x2 (results in float2x2)
-        // We assume that for * + - / % ^ | & << >>
-        bool lhsIsVecMat = typeInfoLhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix);
-        bool rhsIsVecMat = typeInfoRhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix);
-        if (lhsIsVecMat !=/*xor*/ rhsIsVecMat)
-        {
-            auto scalarOperand = lhsIsVecMat ? typeInfoRhs : typeInfoLhs;
-            auto vecmatOperand = lhsIsVecMat ? typeInfoLhs : typeInfoRhs;
-            // typeof(vecmat op scalar)->promoted(vecmat)
-            return vecmatOperand.m_arithmeticInfo.PromoteTruncateWith(scalarOperand.m_arithmeticInfo).GenTypeId();
-        }
-        //else if (lhsIsVecMat && rhsIsVecMat)
-        {
-            // typeof(vecmat op vecmat)->promoted(truncated(vecmat))
-            return typeInfoLhs.m_arithmeticInfo.PromoteTruncateWith(typeInfoRhs.m_arithmeticInfo).GenTypeId();
-        }
-        // case left: both sides are scalar.
-        // final logic in case of arithmetic type class: integer/float promotion.
-        //return typeInfoLhs.m_arithmeticInfo.m_conversionRank > typeInfoRhs.m_arithmeticInfo.m_conversionRank ?
-        //    lhs : rhs;
+        // We assume that for all non bool ops: * + - / % ^ | & << >>
+        return PromoteTruncateWith({typeInfoLhs.m_arithmeticInfo, typeInfoRhs.m_arithmeticInfo}).GenTypeId();
     }
 
     QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const

+ 40 - 31
src/AzslcTypes.h

@@ -323,13 +323,13 @@ namespace AZ::ShaderCompiler
         }
 
         //! Get the size of the whole type considering dimensions
-        const uint32_t GetTotalSize() const
+        uint32_t GetTotalSize() const
         {
             return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1);
         }
 
         //! True if the type is a vector type. If it's a vector type it cannot be a matrix as well.
-        const bool IsVector() const
+        bool IsVector() const
         {
             // This treats special cases like 2x1, 3x1 and 4x1 as vectors
             // The behavior is consistent with dxc packing rules
@@ -337,7 +337,7 @@ namespace AZ::ShaderCompiler
         }
 
         //! True if the type is a matrix type. If it's a matrix type it cannot be a vector as well.
-        const bool IsMatrix() const
+        bool IsMatrix() const
         {
             // This fails special cases like 2x1, 3x1 and 4x1,
             //   but allows cases like 1x2, 1x3 and 1x4.
@@ -345,8 +345,13 @@ namespace AZ::ShaderCompiler
             return m_rows > 0 && m_cols > 1;
         }
 
-        //! If initialized as a fundamental -> not empty.
-        const bool IsEmpty() const
+        bool IsScalar() const
+        {
+            return m_rows <= 1 && m_cols <= 1;  // float, float1, float1x1 are scalars.
+        }
+
+        //! Non-created state
+        bool IsEmpty() const
         {
             return m_underlyingScalar == -1;
         }
@@ -375,32 +380,6 @@ namespace AZ::ShaderCompiler
             }
         }
 
-        //! Create a new arithmetic traits promoted by necessity (through a binary operation usually)
-        //! of compatibility with a second operand of arithmetic typeclass.
-        //! For example: type(half{} + int{})->half
-        //!              type(float3x3{} * double{})->double3x3
-        //! And with columns & rows truncated to the smallest operand,
-        //!   as part of the implicit necessary cast for operation compatibility.
-        ArithmeticTraits PromoteTruncateWith(const ArithmeticTraits& secondOperand) const
-        {
-            ArithmeticTraits copy{*this};
-            // The higher ranking underlying wins independently of global object size
-            if (secondOperand.m_conversionRank > m_conversionRank)
-            {
-                copy.m_underlyingScalar = secondOperand.m_underlyingScalar;
-            }
-            if (secondOperand.m_rows > 0 && m_rows > 0)
-            {
-                copy.m_rows = std::min(m_rows, secondOperand.m_rows);
-            }
-            if (secondOperand.m_cols > 0 && m_cols > 0)
-            {
-                copy.m_cols = std::min(m_cols, secondOperand.m_cols);
-            }
-            copy.ResolveBaseSizeAndRank();
-            return copy;
-        }
-
         uint32_t m_baseSize = 0;                 //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct
         uint32_t m_rows = 0;                     //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix
         uint32_t m_cols = 0;                     //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix
@@ -518,6 +497,36 @@ namespace AZ::ShaderCompiler
         return toReturn;
     }
 
+    //! Create a new arithmetic traits promoted by necessity (through a binary operation usually)
+    //! of compatibility with a second operand of arithmetic typeclass.
+    //! For example: type(half{} + int{})->half
+    //!              type(float3x3{} * double{})->double3x3
+    //! And with columns & rows truncated to the smallest operand,
+    //!   as part of the implicit necessary cast for operation compatibility.
+    inline ArithmeticTraits PromoteTruncateWith(Pair<ArithmeticTraits> operands)
+    {
+        auto [_1, _2] = operands;
+        // put the scalar last (will be useful later if there is only one scalar operand)
+        SwapIf(_1, _2, _1.IsScalar());
+        // now, let's construct the result in _1
+
+        if (_2.m_conversionRank > _1.m_conversionRank) // the higher ranking underlying wins; independently of full object size
+        {
+            _1.m_underlyingScalar = _2.m_underlyingScalar;
+        }
+        // cases: scalar-scalar : no dim change
+        //        vecmat-scalar : no dim change, since result is dim(vecmat) which is in _1
+        //        scalar-vecmat : impossible (sorted by swap above)
+        //        vecmat-vecmat : min(_1,_2)
+        if (!_1.IsScalar() && !_2.IsScalar())
+        {
+            _1.m_rows = std::min(_1.m_rows, _2.m_rows);
+            _1.m_cols = std::min(_1.m_cols, _2.m_cols);
+        }
+        _1.ResolveBaseSizeAndRank();
+        return _1;
+    }
+
     MAKE_REFLECTABLE_ENUM(RootParamType,
         SRV,               // t
         UAV,               // u

+ 9 - 0
src/GenericUtils.h

@@ -624,6 +624,15 @@ namespace AZ
         RemoveDuplicatesKeepOrder(lhs);
     }
 
+    //! Conditional swap algorithm
+    template<typename T>
+    void SwapIf(T&& a, T&& b, bool condition)
+    {
+        if (condition)
+        {
+            std::swap(std::forward<T>(a), std::forward<T>(b));
+        }
+    }
 }
 
 #ifndef NDEBUG

+ 4 - 0
src/MetaUtils.h

@@ -366,6 +366,10 @@ namespace AZ
 
     template <class Default, template<class...> class Op, class... Args>
     using DetectedOr_t = typename DetectedOr<Default, Op, Args...>::type;
+
+    //! define a Pair typealias that has two same T, without the need to repeat yourself
+    template <typename T>
+    using Pair = std::pair<T, T>;
 }
 
 #ifndef NDEBUG