瀏覽代碼

Not init undef when recover init value for patch constant function. (#3141)

* Not init undef when recover init value for patch constant function.
Also lower ConstantAggregateZero directly when lower matrix.
Xiang Li 4 年之前
父節點
當前提交
8266d38b23

+ 3 - 0
lib/HLSL/HLMatrixLowerPass.cpp

@@ -370,6 +370,9 @@ Value* HLMatrixLowerPass::getLoweredByValOperand(Value *Val, IRBuilder<> &Builde
       return LoweredVal;
     }
   }
+  // Lower mat 0 to vec 0.
+  if (isa<ConstantAggregateZero>(Val))
+    return ConstantAggregateZero::get(LoweredTy);
 
   // Return a mat-to-vec translation stub
   FunctionType *TranslationStubTy = FunctionType::get(LoweredTy, { Ty }, /* isVarArg */ false);

+ 2 - 0
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -3353,6 +3353,8 @@ void CGMSHLSLRuntime::FinishCodeGen() {
           if (GV.getName() == "llvm.global_ctors")
             continue;
           Value *V = GV.getInitializer();
+          if (isa<UndefValue>(V))
+            continue;
           B.CreateStore(V, &GV);
         }
       }

+ 78 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/matrix/hs_static_mat.hlsl

@@ -0,0 +1,78 @@
+// RUN: %dxc -E main -T hs_6_0 %s | FileCheck %s
+
+// Make sure not crash.
+// CHECK:define void @"\01?patchConstantF@@YA?AUPatchConstantOutput@@V?$OutputPatch@UHsOutput@@$02@@@Z"()
+// CHECK:define void @main()
+struct ST
+{
+ float4x4 m;
+
+};
+
+cbuffer S : register (b1)
+{
+ float4x4 m;
+ };
+
+static const struct {
+
+ float4x4 m;
+} M = {m};
+
+static ST staticST;
+
+ST GetM()
+{
+ ST Result;
+ Result.m= M.m;
+ return Result; 
+}
+
+ST GetST()
+{
+ return GetM();
+}
+
+ struct HsInput
+ {
+  float4x4 mat : MAT;
+  float4 Position : SV_Position;
+ };
+
+struct HsOutput
+{
+ float4x4 mat : MAT;
+  float4 Position : SV_Position;
+};
+
+struct PatchConstantOutput
+{
+
+ float TessFactor[3] : SV_TessFactor;
+ float InsideTessFactor : SV_InsideTessFactor;
+ float4 m : M;
+};
+
+PatchConstantOutput patchConstantF( const OutputPatch<HsOutput, 3 > I )
+ {
+  PatchConstantOutput O = (PatchConstantOutput)0;
+
+   staticST = GetST();
+   O.m = mul (staticST.m, float4(1,3,2,3));
+   return O;
+}
+
+ [domain("tri")]
+ [patchconstantfunc("patchConstantF")]
+ [outputcontrolpoints( 3 )]
+ [maxtessfactor(12)]
+ [partitioning("fractional_even")][outputtopology("triangle_cw")]
+ HsOutput main( InputPatch< HsInput , 12 > I, uint ControlPointID : SV_OutputControlPointID )
+ {
+
+staticST = GetST();
+
+HsOutput O = (HsOutput) 0;
+  O.mat = staticST.m;
+return O;
+}