Jelajahi Sumber

Fix use constant corruption from ReplaceUsesForLoweredUDT (#4452)

ReplaceUsesForLoweredUDT was calling use.set(NewV), or calling dropAllReferences() on constant user.
This results in modification of the constant operand, which is illegal and led to corruption of constants.

Loop in ReplaceUsesForLoweredUDT relies on all uses being eliminated to terminate, so constant or potentially constant users are now handled with legal replacements and removeDeadConstantUsers at the end as needed.
Tex Riddell 3 tahun lalu
induk
melakukan
5118e1876c

+ 16 - 6
lib/HLSL/HLLowerUDT.cpp

@@ -201,8 +201,9 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
   while (!V->use_empty()) {
     Use &use = *V->use_begin();
     User *user = use.getUser();
-    // Clear use to prevent infinite loop on unhandled case.
-    use.set(UndefValue::get(V->getType()));
+    if (Instruction *I = dyn_cast<Instruction>(user)) {
+      use.set(UndefValue::get(I->getType()));
+    }
 
     if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
       // Load for non-matching type should only be vector
@@ -250,7 +251,6 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
       Constant *NewGEP = ConstantExpr::getGetElementPtr(
         nullptr, cast<Constant>(NewV), idxList, true);
       ReplaceUsesForLoweredUDT(GEP, NewGEP);
-      GEP->dropAllReferences();
 
     } else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
       // Address space cast
@@ -267,6 +267,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
         // if alreday bitcast to new type, just replace the bitcast
         // with the new value (already translated user function)
         BC->replaceAllUsesWith(NewV);
+        BC->eraseFromParent();
       } else {
         // Could be i8 for memcpy?
         // Replace bitcast argument with new value
@@ -288,11 +289,13 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
         } else {
           // Could be i8 for memcpy?
           // Replace bitcast argument with new value
-          use.set(NewV);
+          CE->replaceAllUsesWith(
+              ConstantExpr::getBitCast(cast<Constant>(NewV), CE->getType()));
         }
       } else {
-        DXASSERT(0, "unhandled constant expr for lowered UTD");
-        CE->dropAllReferences();  // better than infinite loop on release
+        DXASSERT(0, "unhandled constant expr for lowered UDT");
+        // better than infinite loop on release
+        CE->replaceAllUsesWith(UndefValue::get(CE->getType()));
       }
 
     } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
@@ -430,10 +433,17 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
 
       default:
         DXASSERT(0, "invalid opcode");
+        // Replace user with undef to prevent infinite loop on unhandled case.
+        user->replaceAllUsesWith(UndefValue::get(user->getType()));
       }
     } else {
       // What else?
       DXASSERT(false, "case not handled.");
+      // Replace user with undef to prevent infinite loop on unhandled case.
+      user->replaceAllUsesWith(UndefValue::get(user->getType()));
     }
+    // Clean up dead constant users to prevent infinite loop
+    if (Constant *CV = dyn_cast<Constant>(V))
+      CV->removeDeadConstantUsers();
   }
 }

+ 29 - 0
tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload-method.hlsl

@@ -0,0 +1,29 @@
+// RUN: %dxc -T as_6_5 %s | FileCheck %s
+
+// Ensure groupshared payload still accepted when initialized with method.
+// CHECK: @[[g_payload:.*]] = addrspace(3) global
+// CHECK: store i32 {{.*}} i32 addrspace(3)*
+// CHECK: store i32 {{.*}} i32 addrspace(3)*
+// CHECK: call void @dx.op.dispatchMesh
+// CHECK-SAME: addrspace(3)* nonnull @[[g_payload]]
+
+struct SharedPayload
+{
+  uint2 m_a;
+
+  void Foo( in float3 v3 )
+  {
+    uint3 f16Vec3 = f32tof16(v3);
+    m_a.x = f16Vec3.x | (f16Vec3.y<<16);
+    m_a.y = f16Vec3.z;
+  }
+};
+
+groupshared SharedPayload g_payload;
+
+[numthreads(8, 8, 1)]
+void main()
+{
+  g_payload.Foo( 1.0.xxx );
+  DispatchMesh(1,1,1,g_payload);
+}