Browse Source

Fix nondeterminism around patch constant functions (#2099)

When multiple hull shaders point to a single patch constant buffer, we would iterate over the hull shaders, during which we would apply a transformation on the patch constant buffer and erase it. This means that subsequent hull shaders will be pointing to an already-deleted patch constant buffer. If the pointer got reused by a new function, we got into weird situations were we would remove a newer function while trying to clean an invalidated pointer.

The solution is to bin hull shaders by their common patch constant function. Transform the patch constant function without deleting the original, performing the replace on every using hull shader function, and only then deleting the original function.
Tristan Labelle 6 years ago
parent
commit
4167dd46f2

+ 50 - 28
lib/HLSL/DxilPreparePasses.cpp

@@ -171,6 +171,7 @@ static void TransferEntryFunctionAttributes(Function *F, Function *NewFunc) {
     NewFunc->addFnAttr(attrKind, attrValue);
     NewFunc->addFnAttr(attrKind, attrValue);
 }
 }
 
 
+// If this returns non-null, the old function F has been stripped and can be deleted.
 static Function *StripFunctionParameter(Function *F, DxilModule &DM,
 static Function *StripFunctionParameter(Function *F, DxilModule &DM,
     DenseMap<const Function *, DISubprogram *> &FunctionDIs) {
     DenseMap<const Function *, DISubprogram *> &FunctionDIs) {
   if (F->arg_empty() && F->getReturnType()->isVoidTy()) {
   if (F->arg_empty() && F->getReturnType()->isVoidTy()) {
@@ -213,7 +214,6 @@ static Function *StripFunctionParameter(Function *F, DxilModule &DM,
     DM.ReplaceDxilEntryProps(F, NewFunc);
     DM.ReplaceDxilEntryProps(F, NewFunc);
   }
   }
   DM.GetTypeSystem().EraseFunctionAnnotation(F);
   DM.GetTypeSystem().EraseFunctionAnnotation(F);
-  F->eraseFromParent();
   DM.GetTypeSystem().AddFunctionAnnotation(NewFunc);
   DM.GetTypeSystem().AddFunctionAnnotation(NewFunc);
   return NewFunc;
   return NewFunc;
 }
 }
@@ -399,50 +399,72 @@ private:
         makeSubprogramMap(M);
         makeSubprogramMap(M);
     // Strip parameters of entry function.
     // Strip parameters of entry function.
     if (!IsLib) {
     if (!IsLib) {
-      if (Function *PatchConstantFunc = DM.GetPatchConstantFunction()) {
-        PatchConstantFunc =
-            StripFunctionParameter(PatchConstantFunc, DM, FunctionDIs);
-        if (PatchConstantFunc) {
-          DM.SetPatchConstantFunction(PatchConstantFunc);
+      if (Function *OldPatchConstantFunc = DM.GetPatchConstantFunction()) {
+        Function *NewPatchConstantFunc =
+            StripFunctionParameter(OldPatchConstantFunc, DM, FunctionDIs);
+        if (NewPatchConstantFunc) {
+          DM.SetPatchConstantFunction(NewPatchConstantFunc);
+
+          // Erase once the DxilModule doesn't track the old function anymore
+          DXASSERT(DM.IsPatchConstantShader(NewPatchConstantFunc) && !DM.IsPatchConstantShader(OldPatchConstantFunc),
+            "Error while migrating to parameter-stripped patch constant function.");
+          OldPatchConstantFunc->eraseFromParent();
         }
         }
       }
       }
 
 
-      if (Function *EntryFunc = DM.GetEntryFunction()) {
+      if (Function *OldEntryFunc = DM.GetEntryFunction()) {
         StringRef Name = DM.GetEntryFunctionName();
         StringRef Name = DM.GetEntryFunctionName();
-        EntryFunc->setName(Name);
-        EntryFunc = StripFunctionParameter(EntryFunc, DM, FunctionDIs);
-        if (EntryFunc) {
-          DM.SetEntryFunction(EntryFunc);
+        OldEntryFunc->setName(Name);
+        Function *NewEntryFunc = StripFunctionParameter(OldEntryFunc, DM, FunctionDIs);
+        if (NewEntryFunc) {
+          DM.SetEntryFunction(NewEntryFunc);
+          OldEntryFunc->eraseFromParent();
         }
         }
       }
       }
     } else {
     } else {
       std::vector<Function *> entries;
       std::vector<Function *> entries;
       // Handle when multiple hull shaders point to the same patch constant function
       // Handle when multiple hull shaders point to the same patch constant function
-      DenseMap<Function*,Function*> patchConstantUpdates;
+      MapVector<Function*, llvm::SmallVector<Function*, 2>> PatchConstantFuncUsers;
       for (iplist<Function>::iterator F : M.getFunctionList()) {
       for (iplist<Function>::iterator F : M.getFunctionList()) {
         if (DM.IsEntryThatUsesSignatures(F)) {
         if (DM.IsEntryThatUsesSignatures(F)) {
           auto *FT = F->getFunctionType();
           auto *FT = F->getFunctionType();
           // Only do this when has parameters.
           // Only do this when has parameters.
-          if (FT->getNumParams() > 0 || !FT->getReturnType()->isVoidTy())
+          if (FT->getNumParams() > 0 || !FT->getReturnType()->isVoidTy()) {
             entries.emplace_back(F);
             entries.emplace_back(F);
+          }
+
+          DxilFunctionProps& props = DM.GetDxilFunctionProps(F);
+          if (props.IsHS() && props.ShaderProps.HS.patchConstantFunc) {
+            FunctionType* PatchConstantFuncTy = props.ShaderProps.HS.patchConstantFunc->getFunctionType();
+            if (PatchConstantFuncTy->getNumParams() > 0 || !PatchConstantFuncTy->getReturnType()->isVoidTy()) {
+              // Accumulate all hull shaders using a given patch constant function,
+              // so we can update it once and fix all hull shaders, without having an intermediary
+              // state where some hull shaders point to a destroyed patch constant function.
+              PatchConstantFuncUsers[props.ShaderProps.HS.patchConstantFunc].emplace_back(F);
+            }
+          }
         }
         }
       }
       }
-      for (Function *entry : entries) {
-        DxilFunctionProps &props = DM.GetDxilFunctionProps(entry);
-        if (props.IsHS()) {
-          // Strip patch constant function first.
-          Function* patchConstFunc = props.ShaderProps.HS.patchConstantFunc;
-          auto it = patchConstantUpdates.find(patchConstFunc);
-          if (it == patchConstantUpdates.end()) {
-            patchConstFunc = patchConstantUpdates[patchConstFunc] =
-                StripFunctionParameter(patchConstFunc, DM, FunctionDIs);
-          } else {
-            patchConstFunc = it->second;
-          }
-          if (patchConstFunc)
-            DM.SetPatchConstantFunctionForHS(entry, patchConstFunc);
+
+      // Strip patch constant functions first
+      for (auto &PatchConstantFuncEntry : PatchConstantFuncUsers) {
+        Function* OldPatchConstantFunc = PatchConstantFuncEntry.first;
+        Function* NewPatchConstantFunc = StripFunctionParameter(OldPatchConstantFunc, DM, FunctionDIs);
+        if (NewPatchConstantFunc) {
+          // Update all user hull shaders
+          for (Function *HullShaderFunc : PatchConstantFuncEntry.second)
+            DM.SetPatchConstantFunctionForHS(HullShaderFunc, NewPatchConstantFunc);
+
+          // Erase once the DxilModule doesn't track the old function anymore
+          DXASSERT(DM.IsPatchConstantShader(NewPatchConstantFunc) && !DM.IsPatchConstantShader(OldPatchConstantFunc),
+            "Error while migrating to parameter-stripped patch constant function.");
+          OldPatchConstantFunc->eraseFromParent();
         }
         }
-        StripFunctionParameter(entry, DM, FunctionDIs);
+      }
+
+      for (Function *OldEntry : entries) {
+        Function *NewEntry = StripFunctionParameter(OldEntry, DM, FunctionDIs);
+        if (NewEntry) OldEntry->eraseFromParent();
       }
       }
     }
     }
   }
   }

+ 0 - 0
tools/clang/test/CodeGenHLSL/batch/misc/d3dreflect/lib_hs_export1.hlsl.disabled → tools/clang/test/CodeGenHLSL/batch/misc/d3dreflect/lib_hs_export1.hlsl


+ 0 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/library/lib_hs_shaders_only.hlsl.disabled → tools/clang/test/CodeGenHLSL/batch/shader_stages/library/lib_hs_shaders_only.hlsl