Răsfoiți Sursa

Prevent texture of structs/matrices/arrays, and don't crash (#1744)

* Add a diagnostic error for non-scalar/vector texel types
* Harden semantic analysis logic against unexpected types
Tristan Labelle 6 ani în urmă
părinte
comite
7fcbdb1969

+ 1 - 0
lib/HLSL/HLOperationLower.cpp

@@ -3299,6 +3299,7 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
 
   Type *Ty = helper.retVal->getType();
   if (Ty->isPointerTy()) {
+    DXASSERT(!DxilResource::IsAnyTexture(RK), "Textures should not be treated as structured buffers.");
     TranslateStructBufSubscript(cast<CallInst>(helper.retVal), helper.handle,
                                 helper.status, OP, DL);
     return;

+ 16 - 2
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -2624,13 +2624,27 @@ bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc,
   if (kind == hlsl::DxilResource::Kind::Texture2DMS ||
       kind == hlsl::DxilResource::Kind::Texture2DMSArray) {
     const ClassTemplateSpecializationDecl *templateDecl =
-        dyn_cast<ClassTemplateSpecializationDecl>(RD);
+        cast<ClassTemplateSpecializationDecl>(RD);
     const clang::TemplateArgument &sampleCountArg =
         templateDecl->getTemplateArgs()[1];
     uint32_t sampleCount = sampleCountArg.getAsIntegral().getLimitedValue();
     hlslRes->SetSampleCount(sampleCount);
   }
 
+  if (hlsl::DxilResource::IsAnyTexture(kind)) {
+    const ClassTemplateSpecializationDecl *templateDecl = cast<ClassTemplateSpecializationDecl>(RD);
+    const clang::TemplateArgument &texelTyArg = templateDecl->getTemplateArgs()[0];
+    llvm::Type *texelTy = CGM.getTypes().ConvertType(texelTyArg.getAsType());
+    if (!texelTy->isFloatingPointTy() && !texelTy->isIntegerTy()
+      && !hlsl::IsHLSLVecType(texelTyArg.getAsType())) {
+      DiagnosticsEngine &Diags = CGM.getDiags();
+      unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
+        "texture resource texel type must be scalar or vector");
+      Diags.Report(loc, DiagID);
+      return false;
+    }
+  }
+
   if (kind != hlsl::DxilResource::Kind::StructuredBuffer) {
     QualType Ty = resultTy;
     QualType EltTy = Ty;
@@ -2697,7 +2711,7 @@ bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc,
   if (kind == hlsl::DxilResource::Kind::TypedBuffer ||
       kind == hlsl::DxilResource::Kind::StructuredBuffer) {
     const ClassTemplateSpecializationDecl *templateDecl =
-        dyn_cast<ClassTemplateSpecializationDecl>(RD);
+        cast<ClassTemplateSpecializationDecl>(RD);
 
     const clang::TemplateArgument &retTyArg =
         templateDecl->getTemplateArgs()[0];

+ 22 - 17
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -1573,24 +1573,26 @@ const char* g_ArBasicTypeNames[] =
 
 C_ASSERT(_countof(g_ArBasicTypeNames) == AR_BASIC_MAXIMUM_COUNT);
 
+static bool IsValidBasicKind(ArBasicKind kind) {
+  return kind != AR_BASIC_COUNT &&
+    kind != AR_BASIC_NONE &&
+    kind != AR_BASIC_UNKNOWN &&
+    kind != AR_BASIC_NOCAST &&
+    kind != AR_BASIC_POINTER &&
+    kind != AR_OBJECT_RENDERTARGETVIEW &&
+    kind != AR_OBJECT_DEPTHSTENCILVIEW &&
+    kind != AR_OBJECT_COMPUTESHADER &&
+    kind != AR_OBJECT_DOMAINSHADER &&
+    kind != AR_OBJECT_GEOMETRYSHADER &&
+    kind != AR_OBJECT_HULLSHADER &&
+    kind != AR_OBJECT_PIXELSHADER &&
+    kind != AR_OBJECT_VERTEXSHADER &&
+    kind != AR_OBJECT_PIXELFRAGMENT &&
+    kind != AR_OBJECT_VERTEXFRAGMENT;
+}
 // kind should never be a flag value or effects framework type - we simply do not expect to deal with these
 #define DXASSERT_VALIDBASICKIND(kind) \
-  DXASSERT(\
-  kind != AR_BASIC_COUNT && \
-  kind != AR_BASIC_NONE && \
-  kind != AR_BASIC_UNKNOWN && \
-  kind != AR_BASIC_NOCAST && \
-  kind != AR_BASIC_POINTER && \
-  kind != AR_OBJECT_RENDERTARGETVIEW && \
-  kind != AR_OBJECT_DEPTHSTENCILVIEW && \
-  kind != AR_OBJECT_COMPUTESHADER && \
-  kind != AR_OBJECT_DOMAINSHADER && \
-  kind != AR_OBJECT_GEOMETRYSHADER && \
-  kind != AR_OBJECT_HULLSHADER && \
-  kind != AR_OBJECT_PIXELSHADER && \
-  kind != AR_OBJECT_VERTEXSHADER && \
-  kind != AR_OBJECT_PIXELFRAGMENT && \
-  kind != AR_OBJECT_VERTEXFRAGMENT, "otherwise caller is using a special flag or an unsupported kind value");
+  DXASSERT(IsValidBasicKind(kind), "otherwise caller is using a special flag or an unsupported kind value");
 
 static
 const char* g_DeprecatedEffectObjectNames[] =
@@ -5582,7 +5584,10 @@ bool HLSLExternalSource::MatchArguments(
           return false;
         }
         pEltType = GetTypeElementKind(objectElement);
-        DXASSERT_VALIDBASICKIND(pEltType);
+        if (!IsValidBasicKind(pEltType)) {
+          // This can happen with Texture2D<Struct> or other invalid declarations
+          return false;
+        }
       }
       else {
         pEltType = ComponentType[pArgument->uComponentTypeId];

+ 13 - 0
tools/clang/test/CodeGenHLSL/quick-test/texture_of_array_error.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -T ps_6_0 -E main %s | FileCheck %s
+// CHECK: error: texture resource texel type must be scalar or vector
+typedef float a[4];
+Texture2D<a> t;
+RWTexture2D<a> rwt;
+SamplerState s;
+float main(float2 f2 : F2, int2 i2 : I) : SV_TARGET
+{
+    // Ensure semantic analysis doesn't crash
+    rwt[i2] = t.Load(int3(i2, 0));
+    t.Gather(s, f2, i2); // Test template resolution with INTRIN_COMPTYPE_FROM_TYPE_ELT0
+    return 0;
+}

+ 13 - 0
tools/clang/test/CodeGenHLSL/quick-test/texture_of_matrix_error.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -T ps_6_0 -E main %s | FileCheck %s
+// Note: FXC accepts this
+// CHECK: error: texture resource texel type must be scalar or vector
+Texture2D<float1x1> t;
+RWTexture2D<float1x1> rwt;
+SamplerState s;
+float main(float2 f2 : F2, int2 i2 : I) : SV_TARGET
+{
+    // Ensure semantic analysis doesn't crash
+    rwt[i2] = t.Load(int3(i2, 0));
+    t.Gather(s, f2, i2); // Test template resolution with INTRIN_COMPTYPE_FROM_TYPE_ELT0
+    return 0;
+}

+ 13 - 0
tools/clang/test/CodeGenHLSL/quick-test/texture_of_struct_error.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -T ps_6_0 -E main %s | FileCheck %s
+// CHECK: error: texture resource texel type must be scalar or vector
+struct Struct { float f; };
+Texture2D<Struct> t;
+RWTexture2D<Struct> rwt;
+SamplerState s;
+float main(float2 f2 : F2, int2 i2 : I) : SV_TARGET
+{
+    // Ensure semantic analysis doesn't crash
+    rwt[i2] = t.Load(int3(i2, 0));
+    t.Gather(s, f2, i2); // Test template resolution with INTRIN_COMPTYPE_FROM_TYPE_ELT0
+    return 0;
+}