浏览代码

Fix crash when write to geometry shader input. (#36)

Xiang Li 8 年之前
父节点
当前提交
deb9f3fd27

+ 23 - 5
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -4069,8 +4069,11 @@ void SROA_Parameter_HLSL::flattenArgument(
                 DXASSERT(data->getType()->isPointerTy(),
                          "Append value must be pointer.");
                 IRBuilder<> Builder(CI);
-                Value *ldInst = Builder.CreateLoad(data);
-                Builder.CreateStore(ldInst, outputVal);
+
+                llvm::SmallVector<llvm::Value *, 16> idxList;
+                SplitCpy(data->getType(), outputVal, data, idxList,
+                         /*bAllowReplace*/ false, Builder);
+
                 CI->setArgOperand(HLOperandIndex::kStreamAppendDataOpIndex, outputVal);
               }
               else {
@@ -4089,9 +4092,13 @@ void SROA_Parameter_HLSL::flattenArgument(
                 DXASSERT_LOCALVAR(eltCount, eltCount == EltPtrList.size(), "invalid element count");
 
                 for (unsigned i = HLOperandIndex::kStreamAppendDataOpIndex; i < CI->getNumArgOperands(); i++) {
-                  Value *Elt = Builder.CreateLoad(CI->getArgOperand(i));
-                  Value *EltPtr = EltPtrList[i-HLOperandIndex::kStreamAppendDataOpIndex];
-                  Builder.CreateStore(Elt, EltPtr);
+                  Value *DataPtr = CI->getArgOperand(i);
+                  Value *EltPtr =
+                      EltPtrList[i - HLOperandIndex::kStreamAppendDataOpIndex];
+
+                  llvm::SmallVector<llvm::Value *, 16> idxList;
+                  SplitCpy(DataPtr->getType(), EltPtr, DataPtr, idxList,
+                           /*bAllowReplace*/ false, Builder);
                   CI->setArgOperand(i, EltPtr);
                 }
               }
@@ -4255,6 +4262,17 @@ static void LegalizeDxilInputOutputs(Function *F, DxilFunctionAnnotation *EntryA
       bNeedTemp = true;
       bLoadOutputFromTemp = true;
       bStoreInputToTemp = true;
+    } else if (bLoad && bStore) {
+      bNeedTemp = true;
+      switch (qual) {
+      case DxilParamInputQual::InputPrimitive:
+      case DxilParamInputQual::InputPatch:
+      case DxilParamInputQual::OutputPatch:
+        bStoreInputToTemp = true;
+        break;
+      default:
+        DXASSERT(0, "invalid input qual here");
+      }
     }
 
     if (HLMatrixLower::IsMatrixType(Ty)) {

+ 29 - 0
tools/clang/test/CodeGenHLSL/SimpleGS5.hlsl

@@ -0,0 +1,29 @@
+// RUN: %dxc -E main -T gs_6_0 %s | FileCheck %s
+
+// CHECK: InputPrimitive=patch2
+// CHECK: emitStream
+// CHECK: cutStream
+// CHECK: i32 24}
+
+struct GSOut {
+  float2 uv : TEXCOORD0;
+  float4 clr : COLOR;
+  float4 pos : SV_Position;
+  float3 norm[2] : NORMAL;
+};
+
+cbuffer b : register(b0) {
+  float2 invViewportSize;
+};
+
+// geometry shader that outputs 3 vertices from a point
+[maxvertexcount(3)]
+[instance(24)]
+void main(InputPatch<GSOut, 2>points, inout PointStream<GSOut> stream) {
+
+  points[0].norm[0] = 1;
+  points[0].norm[1] = 2;
+  stream.Append(points[0]);
+
+  stream.RestartStrip();
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -461,6 +461,7 @@ public:
   TEST_METHOD(CodeGenSimpleGS2)
   TEST_METHOD(CodeGenSimpleGS3)
   TEST_METHOD(CodeGenSimpleGS4)
+  TEST_METHOD(CodeGenSimpleGS5)
   TEST_METHOD(CodeGenSimpleHS1)
   TEST_METHOD(CodeGenSimpleHS2)
   TEST_METHOD(CodeGenSimpleHS3)
@@ -2340,6 +2341,10 @@ TEST_F(CompilerTest, CodeGenSimpleGS4) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\SimpleGS4.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenSimpleGS5) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\SimpleGS5.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenSimpleHS1) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\SimpleHS1.hlsl");
 }