浏览代码

Handle scalar args to out vector params (#3246)

Previously, trying to pass a scalar to a vector out parameter would
cause an assert and no truncation warning. This scales back the assert
and adds the missing warning.

Trying to pass a scalar to an inout parameter would cause a crash. This
allows for the necessary splat and avoids the erroneous attempts to
create a cast that leads to the crash.

Finally, as an incidental, this adds output parameter information to an
error that ostensibly required it, but never had it.
Greg Roth 4 年之前
父节点
当前提交
35fda6914e

+ 16 - 6
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -25,6 +25,7 @@
 #include "clang/AST/HlslTypes.h"
 #include "clang/Frontend/CodeGenOptions.h"
 #include "clang/Lex/HLSLMacroExpander.h"
+#include "clang/Sema/SemaDiagnostic.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/StringSet.h"
@@ -5668,6 +5669,12 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
     if (Param->isModifierOut()) {
       castArgList.emplace_back(tmpLV);
       castArgList.emplace_back(argLV);
+      if (isVector && !hlsl::IsHLSLVecType(argType)) {
+        // This assumes only implicit casts because explicit casts can only produce RValues
+        // currently and out parameters are LValues.
+        DiagnosticsEngine &Diags = CGM.getDiags();
+        Diags.Report(Param->getLocation(), diag::warn_hlsl_implicit_vector_truncation);
+      }
     }
 
     // cast before the call
@@ -5691,9 +5698,14 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
           EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
         }
         else {
-          Value *castVal = ConvertScalarOrVector(CGF, outVal, argType, ParamTy);
-          castVal = CGF.EmitToMemory(castVal, ParamTy);
-          CGF.Builder.CreateStore(castVal, tmpArgAddr);
+          if (outVal->getType()->isVectorTy()) {
+            Value *castVal = ConvertScalarOrVector(CGF, outVal, argType, ParamTy);
+            castVal = CGF.EmitToMemory(castVal, ParamTy);
+            CGF.Builder.CreateStore(castVal, tmpArgAddr);
+          } else {
+            // This allows for splatting, unlike the above.
+            SimpleFlatValCopy(CGF, outVal, argType, tmpArgAddr, ParamTy);
+          }
         }
       } else {
         DXASSERT(argAddr, "should be RV or simple LV");
@@ -5739,9 +5751,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
           // Don't need cast.
         } else if (ToTy->getScalarType() == FromTy->getScalarType()) {
           if (ToTy->getScalarType() == ToTy) {
-            DXASSERT(FromTy->isVectorTy() &&
-                         FromTy->getVectorNumElements() == 1,
-                     "must be vector of 1 element");
+            DXASSERT(FromTy->isVectorTy(), "must be vector");
             castVal = CGF.Builder.CreateExtractElement(outVal, (uint64_t)0);
           } else {
             DXASSERT(!FromTy->isVectorTy(), "must be scalar type");

+ 7 - 1
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -9674,7 +9674,6 @@ bool HLSLExternalSource::ValidateCast(
 
   if (!CanConvert(OpLoc, sourceExpr, target, explicitConversion, &remarks, standard))
   {
-    const bool IsOutputParameter = false;
 
     //
     // Check whether the lack of explicit-ness matters.
@@ -9695,6 +9694,13 @@ bool HLSLExternalSource::ValidateCast(
 
     if (!suppressErrors)
     {
+      bool IsOutputParameter = false;
+      if (clang::DeclRefExpr *OutFrom = dyn_cast<clang::DeclRefExpr>(sourceExpr)) {
+        if (ParmVarDecl *Param = dyn_cast<ParmVarDecl>(OutFrom->getDecl())) {
+          IsOutputParameter = Param->isModifierOut();
+        }
+      }
+
       m_sema->Diag(OpLoc, diag::err_hlsl_cannot_convert)
         << explicitForDiagnostics << IsOutputParameter << source << target;
     }

+ 72 - 0
tools/clang/test/HLSLFileCheck/hlsl/functions/arguments/inout_trunc.hlsl

@@ -0,0 +1,72 @@
+// RUN: %dxc -E NocrashMain -T ps_6_0 %s | FileCheck %s -check-prefix=CHK_NOCRASH
+// RUN: %dxc -E WarnMain -T ps_6_0 %s | FileCheck %s -check-prefix=CHK_WARN
+
+// Test that no crashes result when a scalar is provided to an outvar
+// and that the new warning is produced.
+
+// CHK_WARN: warning: implicit truncation of vector type
+// CHK_WARN: warning: implicit truncation of vector type
+// CHK_WARN: warning: implicit truncation of vector type
+// CHK_WARN: warning: implicit truncation of vector type
+// CHK_WARN: warning: implicit truncation of vector type
+// CHK_WARN: warning: implicit truncation of vector type
+// CHK_WARN-NOT: warning: implicit truncation of vector type
+// CHK_NOCRASH: NocrashMain
+
+float val1;
+float val2;
+float val3;
+
+float2 vec2;
+float3 vec3;
+float4 vec4;
+
+void TakeItOut(out float2 it) {
+  it = val1;
+}
+
+void TakeItIn(inout float3 it) {
+  it = val2;
+}
+
+void TakeItIn2(inout float4 it) {
+  it += val3;
+}
+
+void TakeEmOut(out float2 em) {
+  em = vec2;
+}
+
+void TakeEmIn(inout float3 em) {
+  em = vec3;
+}
+
+void TakeEmIn2(inout float4 em) {
+  em += vec4;
+}
+
+
+float2 RunTest(float it, float em)
+{
+  float c = it;
+  TakeItOut(it);
+  TakeItIn(it);
+  TakeItIn2(it);
+
+  TakeEmOut(em);
+  TakeEmIn(em);
+  TakeEmIn2(em);
+  return float2(it, em);
+}
+
+float2 NocrashMain(float it: A, float em: B) : SV_Target
+{
+  return RunTest(it, em);
+}
+
+// Missing out semantic to force filecheck to read stderr and see the warnings.
+float2 WarnMain(float it: A, float em: B)
+{
+  return RunTest(it, em);
+}
+