2
0
Эх сурвалжийг харах

Matrix and vector binary operations and conversions fix (#1170)

The compiler was crashing on binary operations used on vectors and
matrices of varying sizes. This fix aligns the matrix/vector conversion
behavior and allowed combinations with FXC.

Vector initialization from matrices of equivalent dimensions now works
as well.

Fixes: #1090, #1157
Helena Kotas 7 жил өмнө
parent
commit
841e0d470b

+ 35 - 9
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -3022,7 +3022,7 @@ private:
     return ImplicitCastExpr::Create(*m_context, input->getType(), CK_LValueToRValue, input, nullptr, VK_RValue);
     return ImplicitCastExpr::Create(*m_context, input->getType(), CK_LValueToRValue, input, nullptr, VK_RValue);
   }
   }
 
 
-  HRESULT CombineDimensions(QualType leftType, QualType rightType, QualType *resultType);
+  HRESULT CombineDimensions(QualType leftType, QualType rightType, ArTypeObjectKind leftKind, ArTypeObjectKind rightKind, QualType *resultType);
 
 
   clang::TypedefDecl *LookupMatrixShorthandType(HLSLScalarType scalarType, UINT rowCount, UINT colCount) {
   clang::TypedefDecl *LookupMatrixShorthandType(HLSLScalarType scalarType, UINT rowCount, UINT colCount) {
     DXASSERT_NOMSG(scalarType != HLSLScalarType::HLSLScalarType_unknown &&
     DXASSERT_NOMSG(scalarType != HLSLScalarType::HLSLScalarType_unknown &&
@@ -7811,7 +7811,7 @@ bool HLSLExternalSource::ValidatePrimitiveTypeForOperand(SourceLocation loc, Qua
   return isValid;
   return isValid;
 }
 }
 
 
-HRESULT HLSLExternalSource::CombineDimensions(QualType leftType, QualType rightType, QualType *resultType)
+HRESULT HLSLExternalSource::CombineDimensions(QualType leftType, QualType rightType, ArTypeObjectKind leftKind, ArTypeObjectKind rightKind, QualType *resultType)
 {
 {
   UINT leftRows, leftCols;
   UINT leftRows, leftCols;
   UINT rightRows, rightCols;
   UINT rightRows, rightCols;
@@ -7827,11 +7827,31 @@ HRESULT HLSLExternalSource::CombineDimensions(QualType leftType, QualType rightT
     *resultType = rightType;
     *resultType = rightType;
     return S_OK;
     return S_OK;
   } else if (leftRows <= rightRows && leftCols <= rightCols) {
   } else if (leftRows <= rightRows && leftCols <= rightCols) {
-    *resultType = leftType;
-    return S_OK;
+    DXASSERT_NOMSG((leftKind == AR_TOBJ_MATRIX || leftKind == AR_TOBJ_VECTOR) && 
+                   (rightKind == AR_TOBJ_MATRIX || rightKind == AR_TOBJ_VECTOR));
+    if (leftKind == rightKind) {
+      *resultType = leftType;
+      return S_OK;
+    } else {
+      // vector & matrix combination - only 1xN is allowed here
+      if (leftKind == AR_TOBJ_VECTOR && rightRows == 1) {
+        *resultType = leftType;
+        return S_OK;
+      }
+    }
   } else if (rightRows <= leftRows && rightCols <= leftCols) {
   } else if (rightRows <= leftRows && rightCols <= leftCols) {
-    *resultType = rightType;
-    return S_OK;
+    DXASSERT_NOMSG((leftKind == AR_TOBJ_MATRIX || leftKind == AR_TOBJ_VECTOR) && 
+                   (rightKind == AR_TOBJ_MATRIX || rightKind == AR_TOBJ_VECTOR));
+    if (leftKind == rightKind) {
+      *resultType = rightType;
+      return S_OK;
+    } else {
+      // matrix & vector combination - only 1xN is allowed here
+      if (rightKind == AR_TOBJ_VECTOR && leftRows == 1) {
+        *resultType = leftType;
+        return S_OK;
+      }
+    }
   } else if ( (1 == leftRows || 1 == leftCols) &&
   } else if ( (1 == leftRows || 1 == leftCols) &&
               (1 == rightRows || 1 == rightCols)) {
               (1 == rightRows || 1 == rightCols)) {
     // Handles cases where 1xN or Nx1 matrices are involved possibly mixed with vectors
     // Handles cases where 1xN or Nx1 matrices are involved possibly mixed with vectors
@@ -7843,6 +7863,11 @@ HRESULT HLSLExternalSource::CombineDimensions(QualType leftType, QualType rightT
       return S_OK;
       return S_OK;
     }
     }
   }
   }
+  else if (((leftKind == AR_TOBJ_VECTOR && rightKind == AR_TOBJ_MATRIX) ||
+            (leftKind == AR_TOBJ_MATRIX && rightKind == AR_TOBJ_VECTOR)) && leftTotal == rightTotal) {
+    *resultType = leftType;
+    return S_OK;
+  }
 
 
   return E_FAIL;
   return E_FAIL;
 }
 }
@@ -8032,7 +8057,7 @@ void HLSLExternalSource::CheckBinOpForHLSL(
       // Legal dimension combinations are identical, splat, and truncation.
       // Legal dimension combinations are identical, splat, and truncation.
       // ResultTy will be set to whichever type can be converted to, if legal,
       // ResultTy will be set to whichever type can be converted to, if legal,
       // with preference for leftType if both are possible.
       // with preference for leftType if both are possible.
-      if (FAILED(CombineDimensions(leftType, rightType, &ResultTy))) {
+      if (FAILED(CombineDimensions(leftType, rightType, leftObjectKind, rightObjectKind, &ResultTy))) {
         m_sema->Diag(OpLoc, diag::err_hlsl_type_mismatch);
         m_sema->Diag(OpLoc, diag::err_hlsl_type_mismatch);
         return;
         return;
       }
       }
@@ -8042,8 +8067,9 @@ void HLSLExternalSource::CheckBinOpForHLSL(
 
 
     // Here, element kind is combined with dimensions for computation type.
     // Here, element kind is combined with dimensions for computation type.
     UINT rowCount, colCount;
     UINT rowCount, colCount;
+    ArTypeObjectKind resultObjectKind = (leftObjectKind == rightObjectKind ? leftObjectKind : AR_TOBJ_INVALID);
     GetRowsAndColsForAny(ResultTy, rowCount, colCount);
     GetRowsAndColsForAny(ResultTy, rowCount, colCount);
-    ResultTy = NewSimpleAggregateType(AR_TOBJ_INVALID, resultElementKind, 0, rowCount, colCount)->getCanonicalTypeInternal();
+    ResultTy = NewSimpleAggregateType(resultObjectKind, resultElementKind, 0, rowCount, colCount)->getCanonicalTypeInternal();
   }
   }
 
 
   // Perform necessary conversion sequences for LHS and RHS
   // Perform necessary conversion sequences for LHS and RHS
@@ -8276,7 +8302,7 @@ clang::QualType HLSLExternalSource::CheckVectorConditional(
   }
   }
 
 
   // Combine LHS and RHS dimensions
   // Combine LHS and RHS dimensions
-  if (FAILED(CombineDimensions(leftType, rightType, &ResultTy))) {
+  if (FAILED(CombineDimensions(leftType, rightType, leftObjectKind, rightObjectKind, &ResultTy))) {
     m_sema->Diag(QuestionLoc, diag::err_hlsl_conditional_result_dimensions);
     m_sema->Diag(QuestionLoc, diag::err_hlsl_conditional_result_dimensions);
     return QualType();
     return QualType();
   }
   }

+ 55 - 0
tools/clang/test/CodeGenHLSL/quick-test/vector-matrix-binops.hlsl

@@ -0,0 +1,55 @@
+// RUN: %dxc -E main -T ps_6_0 %s  | FileCheck %s
+
+// CHECK: vector-matrix-binops.hlsl:24:14: warning: implicit truncation of vector type
+// CHECK: vector-matrix-binops.hlsl:31:16: warning: implicit truncation of vector type
+// CHECK: vector-matrix-binops.hlsl:36:24: error: type mismatch
+// CHECK: vector-matrix-binops.hlsl:37:27: error: type mismatch
+// CHECK: vector-matrix-binops.hlsl:39:14: warning: implicit truncation of vector type
+// CHECK: vector-matrix-binops.hlsl:52:27: error: type mismatch
+// CHECK: vector-matrix-binops.hlsl:53:27: error: type mismatch
+
+void main() {
+
+    float4 v4 = float4(0.1f, 0.2f, 0.3f, 0.4f);
+    float3 v3 = float3(0.1f, 0.2f, 0.3f);
+    float2 v2 = float2(0.5f, 0.6f);
+    float4x4 m44 = float4x4(v4, v4, v4, v4);
+    float2x2 m22 = float2x2(0.1f, 0.2f, 0.3f, 0.4f);
+    float1x4 m14 = float1x4(v4);
+    float3x2 m32 = float3x2(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f);
+
+    // vector truncation
+    {
+      float2 res1 = v2 * v4; // expected-warning {{implicit truncation of vector type}} 
+      float2 res2 = v4 - v3; // expected-warning {{implicit truncation of vector type}} 
+    }
+
+    // matrix truncation
+    {
+      float1x4 res1 = m44 / m14; // expected-warning {{implicit truncation of vector type}} 
+      float1x4 res2 = m14 - m44; // expected-warning {{implicit truncation of vector type}} 
+      float2x2 res3 = m44 + m32; // expected-warning {{implicit truncation of vector type}} 
+    }
+
+    // matrix and vector binary operation - mismatched dimensions
+    {
+      float4 res1 = v4 * m44; // expected-error {{type mismatch}}
+      float4x4 res2 = m44 + v4; // expected-error {{type mismatch}}
+      float3 res3 = v3 * m14; // expected-warning {{implicit truncation of vector type}} 
+      float2 res4 = m14 / v2; // expected-warning {{implicit truncation of vector type}} 
+    }
+
+    // matrix and vector binary operation - matching dimensions - no warnings expected
+    {
+      float4 res1 = v4 / m22;
+      float2x2 res2 = m22 - v4;
+      float4 res3 = v4 + m14;
+    }
+    
+    // matrix mismatched dimensions
+    {
+      float2x3 m23 = float2x3(1, 2, 3, 4, 5, 6);
+      float3x2 res1 = m23 - m32; // expected-error {{type mismatch}}
+      float1x4 res2 = m14 / m23; // expected-error {{type mismatch}}
+    }
+}