Jelajahi Sumber

Return error on subscript access of index vector (#2837)

When an output index array vector is accessed by element instead of the
whole, DXIL generation must produce an error. Previously, when the
access was performed using the dot operator, the error was produced, but
the processing wasn't terminated, so a second error erroneously saying
that the vector had no member by that name was also produced by falling
through to the struct processing. When the index was performed using a
numerical subscript, no error was produced at all and a crash resulted
when the indices were processed later.

By changing the interface for functions that process records that are
really HLSL Matrix, Vector or Array types, inapplicable structs can be
distinguished from expressions that result in an error and the error can
be returned, preventing the duplicate error.

A check for use of the output indices is performed in the subscript
processing to produce an error in this case.

As part of this, an incidental change was made to short-circuit
processing when an array is found to have a member access that isn't
Length.

Tests are added for all the above.

A very incidental change was made to GetStructuralForm() that was
applied to a duplicate function, but never propagated correctly.
Greg Roth 5 tahun lalu
induk
melakukan
a841ddde5a

+ 15 - 8
tools/clang/include/clang/Sema/SemaHLSL.h

@@ -123,32 +123,38 @@ bool IsConversionToLessOrEqualElements(
   const clang::QualType& targetType,
   bool explicitConversion);
 
-bool LookupMatrixMemberExprForHLSL(
+clang::ExprResult LookupMatrixMemberExprForHLSL(
   clang::Sema* self,
   clang::Expr& BaseExpr,
   clang::DeclarationName MemberName,
   bool IsArrow,
   clang::SourceLocation OpLoc,
-  clang::SourceLocation MemberLoc,
-  _Inout_ clang::ExprResult* result);
+  clang::SourceLocation MemberLoc);
 
-bool LookupVectorMemberExprForHLSL(
+clang::ExprResult LookupVectorMemberExprForHLSL(
   clang::Sema* self,
   clang::Expr& BaseExpr,
   clang::DeclarationName MemberName,
   bool IsArrow,
   clang::SourceLocation OpLoc,
-  clang::SourceLocation MemberLoc,
-  _Inout_ clang::ExprResult* result);
+  clang::SourceLocation MemberLoc);
+
+clang::ExprResult LookupArrayMemberExprForHLSL(
+  clang::Sema* self,
+  clang::Expr& BaseExpr,
+  clang::DeclarationName MemberName,
+  bool IsArrow,
+  clang::SourceLocation OpLoc,
+  clang::SourceLocation MemberLoc);
 
-bool LookupArrayMemberExprForHLSL(
+bool LookupRecordMemberExprForHLSL(
   clang::Sema* self,
   clang::Expr& BaseExpr,
   clang::DeclarationName MemberName,
   bool IsArrow,
   clang::SourceLocation OpLoc,
   clang::SourceLocation MemberLoc,
-  _Inout_ clang::ExprResult* result);
+  clang::ExprResult &result);
 
 clang::ExprResult MaybeConvertScalarToVector(
   _In_ clang::Sema* Self,
@@ -251,6 +257,7 @@ clang::QualType CheckVectorConditional(
 }
 
 bool IsTypeNumeric(_In_ clang::Sema* self, _In_ clang::QualType &type);
+bool IsExprAccessingOutIndicesArray(clang::Expr* BaseExpr);
 
 // This function reads the given declaration TSS and returns the corresponding parsedType with the
 // corresponding type. Replaces the given parsed type with the new type

+ 2 - 1
tools/clang/lib/AST/HlslTypes.cpp

@@ -258,7 +258,8 @@ QualType GetStructuralForm(QualType type) {
     type = RefType ? RefType->getPointeeType() : AttrType->getEquivalentType();
   }
 
-  return type->getCanonicalTypeUnqualified();
+  // Despite its name, getCanonicalTypeUnqualified will preserve const for array elements or something
+  return QualType(type->getCanonicalTypeUnqualified()->getTypePtr(), 0);
 }
 
 uint32_t GetElementCount(clang::QualType type) {

+ 10 - 0
tools/clang/lib/Sema/SemaExpr.cpp

@@ -4117,6 +4117,16 @@ Sema::ActOnArraySubscriptExpr(Scope *S, Expr *base, SourceLocation lbLoc,
     idx = result.get();
   }
 
+  // HLSL Change Starts - Check for subscript access of out indices
+  // Disallow component access for out indices for DXIL path. We still allow
+  // this in SPIR-V path.
+  if (getLangOpts().HLSL && !getLangOpts().SPIRV &&
+      base->getType()->isRecordType() && IsExprAccessingOutIndicesArray(base)) {
+    Diag(lbLoc, diag::err_hlsl_out_indices_array_incorrect_access);
+    return ExprError();
+  }
+  // HLSL Change Ends
+
   // Build an unanalyzed expression if either operand is type-dependent.
   if (getLangOpts().CPlusPlus &&
       (base->isTypeDependent() || idx->isTypeDependent())) {

+ 5 - 20
tools/clang/lib/Sema/SemaExprMember.cpp

@@ -1234,26 +1234,11 @@ static ExprResult LookupMemberExpr(Sema &S, LookupResult &R,
   }
 
   // HLSL Change Starts
-  {
-    ExprResult matrixResult;
-    if (S.getLangOpts().HLSL &&
-      hlsl::LookupMatrixMemberExprForHLSL(&S, *BaseExpr.get(), MemberName, IsArrow, OpLoc, MemberLoc, &matrixResult)) {
-      return matrixResult;
-    }
-  }
-  {
-    ExprResult vectorResult;
-    if (S.getLangOpts().HLSL &&
-      hlsl::LookupVectorMemberExprForHLSL(&S, *BaseExpr.get(), MemberName, IsArrow, OpLoc, MemberLoc, &vectorResult)) {
-      return vectorResult;
-    }
-  }
-  {
-    ExprResult arrayResult;
-    if (S.getLangOpts().HLSL &&
-      hlsl::LookupArrayMemberExprForHLSL(&S, *BaseExpr.get(), MemberName, IsArrow, OpLoc, MemberLoc, &arrayResult)) {
-      return arrayResult;
-    }
+  // Look up HLSL specialty records: Matrix, Vector, Array
+  if (S.getLangOpts().HLSL) {
+    ExprResult res;
+    if (hlsl::LookupRecordMemberExprForHLSL(&S, *BaseExpr.get(), MemberName, IsArrow, OpLoc, MemberLoc, res))
+      return res;
   }
   // HLSL Change Ends
 

+ 66 - 77
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -4603,15 +4603,13 @@ public:
   /// <param name="IsArrow">Whether access is through arrow (a->b) rather than period (a.b).</param>
   /// <param name="OpLoc">Location of access operand.</param>
   /// <param name="MemberLoc">Location of member.</param>
-  /// <param name="result">Result of lookup operation.</param>
-  /// <returns>true if the base type is a matrix and the lookup has been handled.</returns>
-  bool LookupMatrixMemberExprForHLSL(
+  /// <returns>Result of lookup operation.</returns>
+  ExprResult LookupMatrixMemberExprForHLSL(
     Expr& BaseExpr,
     DeclarationName MemberName,
     bool IsArrow,
     SourceLocation OpLoc,
-    SourceLocation MemberLoc,
-    ExprResult* result);
+    SourceLocation MemberLoc);
 
   /// <summary>Performs a member lookup on the specified BaseExpr if it's a vector.</summary>
   /// <param name="BaseExpr">Base expression for member access.</param>
@@ -4619,15 +4617,13 @@ public:
   /// <param name="IsArrow">Whether access is through arrow (a->b) rather than period (a.b).</param>
   /// <param name="OpLoc">Location of access operand.</param>
   /// <param name="MemberLoc">Location of member.</param>
-  /// <param name="result">Result of lookup operation.</param>
-  /// <returns>true if the base type is a vector and the lookup has been handled.</returns>
-  bool LookupVectorMemberExprForHLSL(
+  /// <returns>Result of lookup operation.</returns>
+  ExprResult LookupVectorMemberExprForHLSL(
     Expr& BaseExpr,
     DeclarationName MemberName,
     bool IsArrow,
     SourceLocation OpLoc,
-    SourceLocation MemberLoc,
-    ExprResult* result);
+    SourceLocation MemberLoc);
 
   /// <summary>Performs a member lookup on the specified BaseExpr if it's an array.</summary>
   /// <param name="BaseExpr">Base expression for member access.</param>
@@ -4635,15 +4631,13 @@ public:
   /// <param name="IsArrow">Whether access is through arrow (a->b) rather than period (a.b).</param>
   /// <param name="OpLoc">Location of access operand.</param>
   /// <param name="MemberLoc">Location of member.</param>
-  /// <param name="result">Result of lookup operation.</param>
-  /// <returns>true if the base type is an array and the lookup has been handled.</returns>
-  bool LookupArrayMemberExprForHLSL(
+  /// <returns>Result of lookup operation.</returns>
+  ExprResult LookupArrayMemberExprForHLSL(
     Expr& BaseExpr,
     DeclarationName MemberName,
     bool IsArrow,
     SourceLocation OpLoc,
-    SourceLocation MemberLoc,
-    ExprResult* result);
+    SourceLocation MemberLoc);
 
   /// <summary>If E is a scalar, converts it to a 1-element vector.</summary>
   /// <param name="E">Expression to convert.</param>
@@ -7260,26 +7254,16 @@ MatrixMemberAccessError TryParseMatrixMemberAccess(_In_z_ const char* memberText
   return MatrixMemberAccessError_None;
 }
 
-bool HLSLExternalSource::LookupMatrixMemberExprForHLSL(
+ExprResult HLSLExternalSource::LookupMatrixMemberExprForHLSL(
   Expr& BaseExpr,
   DeclarationName MemberName,
   bool IsArrow,
   SourceLocation OpLoc,
-  SourceLocation MemberLoc,
-  ExprResult* result)
+  SourceLocation MemberLoc)
 {
-  DXASSERT_NOMSG(result != nullptr);
-
   QualType BaseType = BaseExpr.getType();
   DXASSERT(!BaseType.isNull(), "otherwise caller should have stopped analysis much earlier");
-
-  // Assume failure.
-  *result = ExprError();
-
-  if (GetTypeObjectKind(BaseType) != AR_TOBJ_MATRIX)
-  {
-    return false;
-  }
+  DXASSERT(GetTypeObjectKind(BaseType) == AR_TOBJ_MATRIX, "Should only be called on known matrix types");
 
   QualType elementType;
   UINT rowCount, colCount;
@@ -7341,7 +7325,7 @@ bool HLSLExternalSource::LookupMatrixMemberExprForHLSL(
     // processing.
     if (!positions.IsValid)
     {
-      return true;
+      return ExprError();
     }
   }
 
@@ -7361,9 +7345,8 @@ bool HLSLExternalSource::LookupMatrixMemberExprForHLSL(
     positions.ContainsDuplicateElements() ? VK_RValue :
       (IsArrow ? VK_LValue : BaseExpr.getValueKind());
   ExtMatrixElementExpr* matrixExpr = new (m_context)ExtMatrixElementExpr(resultType, VK, &BaseExpr, *member, MemberLoc, positions);
-  *result = matrixExpr;
 
-  return true;
+  return matrixExpr;
 }
 
 enum VectorMemberAccessError {
@@ -7463,7 +7446,7 @@ VectorMemberAccessError TryParseVectorMemberAccess(_In_z_ const char* memberText
   return VectorMemberAccessError_None;
 }
 
-static bool IsExprAccessingOutIndicesArray(Expr* BaseExpr) {
+bool IsExprAccessingOutIndicesArray(Expr* BaseExpr) {
   switch(BaseExpr->getStmtClass()) {
   case Stmt::ArraySubscriptExprClass: {
     ArraySubscriptExpr* ase = cast<ArraySubscriptExpr>(BaseExpr);
@@ -7486,24 +7469,15 @@ static bool IsExprAccessingOutIndicesArray(Expr* BaseExpr) {
   }
 }
 
-bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
+ExprResult HLSLExternalSource::LookupVectorMemberExprForHLSL(
     Expr& BaseExpr,
     DeclarationName MemberName,
     bool IsArrow,
     SourceLocation OpLoc,
-    SourceLocation MemberLoc,
-    ExprResult* result) {
-  DXASSERT_NOMSG(result != nullptr);
-
+    SourceLocation MemberLoc) {
   QualType BaseType = BaseExpr.getType();
   DXASSERT(!BaseType.isNull(), "otherwise caller should have stopped analysis much earlier");
-
-  // Assume failure.
-  *result = ExprError();
-
-  if (GetTypeObjectKind(BaseType) != AR_TOBJ_VECTOR) {
-    return false;
-  }
+  DXASSERT(GetTypeObjectKind(BaseType) == AR_TOBJ_VECTOR, "Should only be called on known vector types");
 
   QualType elementType;
   UINT colCount = GetHLSLVecSize(BaseType);
@@ -7553,7 +7527,7 @@ bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
     // generate the member access expression with the correct arity and continue
     // processing.
     if (!positions.IsValid) {
-      return true;
+      return ExprError();
     }
   }
 
@@ -7561,10 +7535,10 @@ bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
 
   // Disallow component access for out indices for DXIL path. We still allow
   // this in SPIR-V path.
-  if (!getSema()->getLangOpts().SPIRV &&
+  if (!m_sema->getLangOpts().SPIRV &&
       IsExprAccessingOutIndicesArray(&BaseExpr) && positions.Count < colCount) {
     m_sema->Diag(MemberLoc, diag::err_hlsl_out_indices_array_incorrect_access);
-    return false;
+    return ExprError();
   }
 
   // Consume elements
@@ -7581,30 +7555,20 @@ bool HLSLExternalSource::LookupVectorMemberExprForHLSL(
     positions.ContainsDuplicateElements() ? VK_RValue :
       (IsArrow ? VK_LValue : BaseExpr.getValueKind());
   HLSLVectorElementExpr* vectorExpr = new (m_context)HLSLVectorElementExpr(resultType, VK, &BaseExpr, *member, MemberLoc, positions);
-  *result = vectorExpr;
 
-  return true;
+  return vectorExpr;
 }
 
-bool HLSLExternalSource::LookupArrayMemberExprForHLSL(
+ExprResult HLSLExternalSource::LookupArrayMemberExprForHLSL(
   Expr& BaseExpr,
   DeclarationName MemberName,
   bool IsArrow,
   SourceLocation OpLoc,
-  SourceLocation MemberLoc,
-  ExprResult* result) {
-
-  DXASSERT_NOMSG(result != nullptr);
+  SourceLocation MemberLoc) {
 
   QualType BaseType = BaseExpr.getType();
   DXASSERT(!BaseType.isNull(), "otherwise caller should have stopped analysis much earlier");
-
-  // Assume failure.
-  *result = ExprError();
-
-  if (GetTypeObjectKind(BaseType) != AR_TOBJ_ARRAY) {
-    return false;
-  }
+  DXASSERT(GetTypeObjectKind(BaseType) == AR_TOBJ_ARRAY, "Should only be called on known array types");
 
   IdentifierInfo *member = MemberName.getAsIdentifierInfo();
   const char *memberText = member->getNameStart();
@@ -7616,7 +7580,7 @@ bool HLSLExternalSource::LookupArrayMemberExprForHLSL(
       unsigned hlslVer = getSema()->getLangOpts().HLSLVersion;
       if (hlslVer > 2016) {
         m_sema->Diag(MemberLoc, diag::err_hlsl_unsupported_for_version_lower) << "Length" << "2016";
-        return false;
+        return ExprError();
       }
       if (hlslVer == 2016) {
         m_sema->Diag(MemberLoc, diag::warn_deprecated) << "Length";
@@ -7625,11 +7589,13 @@ bool HLSLExternalSource::LookupArrayMemberExprForHLSL(
       UnaryExprOrTypeTraitExpr *arrayLenExpr = new (m_context) UnaryExprOrTypeTraitExpr(
         UETT_ArrayLength, &BaseExpr, m_context->getSizeType(), MemberLoc, BaseExpr.getSourceRange().getEnd());
 
-      *result = arrayLenExpr;
-      return true;
+      return arrayLenExpr;
     }
   }
-  return false;
+  m_sema->Diag(MemberLoc, diag::err_typecheck_member_reference_struct_union)
+    << BaseType << BaseExpr.getSourceRange() << MemberLoc;
+
+  return ExprError();
 }
   
 
@@ -10083,43 +10049,66 @@ bool hlsl::IsConversionToLessOrEqualElements(
     ->IsConversionToLessOrEqualElements(sourceExpr, targetType, explicitConversion);
 }
 
-bool hlsl::LookupMatrixMemberExprForHLSL(
+ExprResult hlsl::LookupMatrixMemberExprForHLSL(
   Sema* self,
   Expr& BaseExpr,
   DeclarationName MemberName,
   bool IsArrow,
   SourceLocation OpLoc,
-  SourceLocation MemberLoc,
-  ExprResult* result)
+  SourceLocation MemberLoc)
 {
   return HLSLExternalSource::FromSema(self)
-    ->LookupMatrixMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc, result);
+    ->LookupMatrixMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc);
 }
 
-bool hlsl::LookupVectorMemberExprForHLSL(
+ExprResult hlsl::LookupVectorMemberExprForHLSL(
   Sema* self,
   Expr& BaseExpr,
   DeclarationName MemberName,
   bool IsArrow,
   SourceLocation OpLoc,
-  SourceLocation MemberLoc,
-  ExprResult* result)
+  SourceLocation MemberLoc)
 {
   return HLSLExternalSource::FromSema(self)
-    ->LookupVectorMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc, result);
+    ->LookupMatrixMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc);
 }
 
-bool hlsl::LookupArrayMemberExprForHLSL(
+ExprResult hlsl::LookupArrayMemberExprForHLSL(
   Sema* self,
   Expr& BaseExpr,
   DeclarationName MemberName,
   bool IsArrow,
   SourceLocation OpLoc,
-  SourceLocation MemberLoc,
-  ExprResult* result)
+  SourceLocation MemberLoc)
 {
   return HLSLExternalSource::FromSema(self)
-    ->LookupArrayMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc, result);
+    ->LookupMatrixMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc);
+}
+
+bool hlsl::LookupRecordMemberExprForHLSL(
+  Sema* self,
+  Expr& BaseExpr,
+  DeclarationName MemberName,
+  bool IsArrow,
+  SourceLocation OpLoc,
+  SourceLocation MemberLoc,
+  ExprResult &result)
+{
+  HLSLExternalSource *source = HLSLExternalSource::FromSema(self);
+  switch (source->GetTypeObjectKind(BaseExpr.getType())) {
+  case AR_TOBJ_MATRIX:
+    result = source->LookupMatrixMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc);
+    return true;
+  case AR_TOBJ_VECTOR:
+    result = source->LookupVectorMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc);
+    return true;
+  case AR_TOBJ_ARRAY:
+    result = source->LookupArrayMemberExprForHLSL(BaseExpr, MemberName, IsArrow, OpLoc, MemberLoc);
+    return true;
+  default:
+    return false;
+  }
+  return false;
 }
 
 clang::ExprResult hlsl::MaybeConvertScalarToVector(

+ 2 - 5
tools/clang/lib/Sema/TreeTransform.h

@@ -2006,9 +2006,7 @@ public:
     ExprResult result;
     DeclarationName Name(&Accessor);
 
-    hlsl::LookupMatrixMemberExprForHLSL(&getSema(), *Base, Name, IsArrowFalse, OpLoc, AccessorLoc, &result);
-
-    return result;
+    return hlsl::LookupMatrixMemberExprForHLSL(&getSema(), *Base, Name, IsArrowFalse, OpLoc, AccessorLoc);
   }
 
   /// \brief Build a new extended vector element access expression.
@@ -2024,9 +2022,8 @@ public:
     ExprResult result;
     DeclarationName Name(&Accessor);
 
-    hlsl::LookupVectorMemberExprForHLSL(&getSema(), *Base, Name, IsArrowFalse, OpLoc, AccessorLoc, &result);
+    return hlsl::LookupVectorMemberExprForHLSL(&getSema(), *Base, Name, IsArrowFalse, OpLoc, AccessorLoc);
 
-    return result;
   }
 
   // HLSL Changes End

+ 15 - 0
tools/clang/test/HLSLFileCheck/hlsl/functions/arguments/arrayArgErr.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK: error: Length is only allowed for HLSL 2016 and lower
+// CHECK: error: member reference base type 'float [2]' is not a structure or union
+// CHECK-NOT: error:
+int fn_int_arr(float arr[2]) {
+  return arr.Length + arr.x;
+}
+
+[numthreads(8,8,1)]
+void main() {
+
+  float arr2[2] = { 1, 2 };
+  int m = fn_int_arr(arr2);
+}

+ 4 - 0
tools/clang/test/HLSLFileCheck/shader_targets/mesh/illegalOutIndicesAssignment.hlsl

@@ -1,6 +1,9 @@
 // RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
 
+// One error for each access method and no additionals
 // CHECK: error: a vector in out indices array must be accessed as a whole
+// CHECK: error: a vector in out indices array must be accessed as a whole
+// CHECK-NOT: error:
 
 #define MAX_VERT 32
 #define MAX_PRIM 16
@@ -51,6 +54,7 @@ void main(
     if (tig % 3) {
       primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
       primIndices[tig / 3].x = 0;
+      primIndices[tig / 3][1] = 0;
       MeshPerPrimitive op;
       op.normal = mpl.normal;
       op.malnor = mpl.malnor;