Răsfoiți Sursa

Fix a crash converting between numerical and non-numerical types. (#2120)

Tristan Labelle 6 ani în urmă
părinte
comite
8368d0c952

+ 1 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -382,6 +382,7 @@ bool IsHLSLLineStreamType(clang::QualType type);
 bool IsHLSLTriangleStreamType(clang::QualType type);
 bool IsHLSLStreamOutputType(clang::QualType type);
 bool IsHLSLResourceType(clang::QualType type);
+bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type);
 bool IsHLSLNumericUserDefinedType(clang::QualType type);
 bool IsHLSLAggregateType(clang::QualType type);
 clang::QualType GetHLSLResourceResultType(clang::QualType type);

+ 3 - 3
tools/clang/lib/AST/HlslTypes.cpp

@@ -91,14 +91,14 @@ bool IsHLSLVecType(clang::QualType type) {
   return false;
 }
 
-static bool IsHLSLNumeric(clang::QualType type) {
+bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type) {
   const clang::Type *Ty = type.getCanonicalType().getTypePtr();
   if (isa<RecordType>(Ty)) {
     if (IsHLSLVecMatType(type))
       return true;
     return IsHLSLNumericUserDefinedType(type);
   } else if (type->isArrayType()) {
-    return IsHLSLNumeric(QualType(type->getArrayElementTypeNoTypeQual(), 0));
+    return IsHLSLNumericOrAggregateOfNumericType(QualType(type->getArrayElementTypeNoTypeQual(), 0));
   }
   return Ty->isBuiltinType();
 }
@@ -117,7 +117,7 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
         name == "RaytracingAccelerationStructure")
       return false;
     for (auto member : RD->fields()) {
-      if (!IsHLSLNumeric(member->getType()))
+      if (!IsHLSLNumericOrAggregateOfNumericType(member->getType()))
         return false;
     }
     return true;

+ 2 - 1
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -8153,7 +8153,8 @@ bool HLSLExternalSource::CanConvert(
         sourceSingleElementBuiltinType = hlsl::GetElementTypeOrType(source)->getAs<BuiltinType>();
       }
 
-      if (sourceSingleElementBuiltinType != nullptr) {
+      // We can only splat to target types that do not contain object/resource types
+      if (sourceSingleElementBuiltinType != nullptr && hlsl::IsHLSLNumericOrAggregateOfNumericType(target)) {
         BuiltinType::Kind kind = sourceSingleElementBuiltinType->getKind();
         switch (kind) {
         case BuiltinType::Kind::UInt:

+ 19 - 0
tools/clang/test/HLSL/conversions-non-numeric-aggregates.hlsl

@@ -0,0 +1,19 @@
+// RUN: %clang_cc1 -Wno-unused-value -fsyntax-only -ffreestanding -verify -verify-ignore-unexpected=note %s
+
+// Tests that conversions between numeric and non-numeric types/aggregates are disallowed.
+
+struct NumStruct { int a; };
+struct ObjStruct { Buffer a; };
+
+void main()
+{
+  (Buffer[1])0; /* expected-error {{cannot convert from 'literal int' to 'Buffer [1]'}} */
+  (ObjStruct)0; /* expected-error {{cannot convert from 'literal int' to 'ObjStruct'}} */
+  (Buffer[1])(int[1])0; /* expected-error {{cannot convert from 'int [1]' to 'Buffer [1]'}} */
+  (ObjStruct)(NumStruct)0; /* expected-error {{cannot convert from 'NumStruct' to 'ObjStruct'}} */
+
+  Buffer oa1[1];
+  ObjStruct os1;
+  (int)oa1; /* expected-error {{cannot convert from 'Buffer [1]' to 'int'}} */
+  (int)os1; /* expected-error {{cannot convert from 'ObjStruct' to 'int'}} */
+}

+ 5 - 0
tools/clang/unittests/HLSL/VerifierTest.cpp

@@ -42,6 +42,7 @@ public:
   TEST_METHOD(RunConstAssign)
   TEST_METHOD(RunConstDefault)
   TEST_METHOD(RunConversionsBetweenTypeShapes)
+  TEST_METHOD(RunConversionsNonNumericAggregates)
   TEST_METHOD(RunCppErrors)
   TEST_METHOD(RunCppErrorsHV2015)
   TEST_METHOD(RunCXX11Attributes)
@@ -171,6 +172,10 @@ TEST_F(VerifierTest, RunConversionsBetweenTypeShapes) {
   CheckVerifiesHLSL(L"conversions-between-type-shapes.hlsl");
 }
 
+TEST_F(VerifierTest, RunConversionsNonNumericAggregates) {
+  CheckVerifiesHLSL(L"conversions-non-numeric-aggregates.hlsl");
+}
+
 TEST_F(VerifierTest, RunCppErrors) {
   CheckVerifiesHLSL(L"cpp-errors.hlsl");
 }