Browse Source

Refactoring DxilEmitMetadata pass.

Xiang Li 8 years ago
parent
commit
46e47f01e5

+ 2 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -46,6 +46,7 @@ ModulePass *createDxilFinalizeModulePass();
 ModulePass *createDxilEmitMetadataPass();
 FunctionPass *createDxilExpandTrigIntrinsicsPass();
 ModulePass *createDxilLoadMetadataPass();
+ModulePass *createDxilDeadFunctionEliminationPass();
 ModulePass *createDxilPrecisePropagatePass();
 FunctionPass *createDxilPreserveAllOutputsPass();
 FunctionPass *createDxilLegalizeResourceUsePass();
@@ -68,6 +69,7 @@ void initializeDxilFinalizeModulePass(llvm::PassRegistry&);
 void initializeDxilEmitMetadataPass(llvm::PassRegistry&);
 void initializeDxilExpandTrigIntrinsicsPass(llvm::PassRegistry&);
 void initializeDxilLoadMetadataPass(llvm::PassRegistry&);
+void initializeDxilDeadFunctionEliminationPass(llvm::PassRegistry&);
 void initializeDxilPrecisePropagatePassPass(llvm::PassRegistry&);
 void initializeDxilPreserveAllOutputsPass(llvm::PassRegistry&);
 void initializeDxilLegalizeResourceUsePassPass(llvm::PassRegistry&);

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -84,6 +84,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDeadInstEliminationPass(Registry);
     initializeDxilAddPixelHitInstrumentationPass(Registry);
     initializeDxilCondenseResourcesPass(Registry);
+    initializeDxilDeadFunctionEliminationPass(Registry);
     initializeDxilEliminateOutputDynamicIndexingPass(Registry);
     initializeDxilEmitMetadataPass(Registry);
     initializeDxilExpandTrigIntrinsicsPass(Registry);

+ 2 - 0
lib/HLSL/DxilLinker.cpp

@@ -618,6 +618,7 @@ void DxilLinkJob::RunPreparePass(Module &M) {
   legacy::PassManager PM;
 
   PM.add(createAlwaysInlinerPass(/*InsertLifeTime*/ false));
+  PM.add(createDxilDeadFunctionEliminationPass());
   // mem2reg.
   PM.add(createPromoteMemoryToRegisterPass());
   // Remove unused functions.
@@ -630,6 +631,7 @@ void DxilLinkJob::RunPreparePass(Module &M) {
   PM.add(createDxilCondenseResourcesPass());
   PM.add(createDxilFinalizeModulePass());
   PM.add(createComputeViewIdStatePass());
+  PM.add(createDxilDeadFunctionEliminationPass());
   PM.add(createDxilEmitMetadataPass());
 
   PM.run(M);

+ 156 - 124
lib/HLSL/DxilPreparePasses.cpp

@@ -92,6 +92,53 @@ ModulePass *llvm::createDxilLoadMetadataPass() {
 
 INITIALIZE_PASS(DxilLoadMetadata, "hlsl-dxilload", "HLSL load DxilModule from metadata", false, false)
 
+///////////////////////////////////////////////////////////////////////////////
+
+namespace {
+class DxilDeadFunctionElimination : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilDeadFunctionElimination () : ModulePass(ID) {}
+
+  const char *getPassName() const override { return "Remove all unused function except entry from DxilModule"; }
+
+  bool runOnModule(Module &M) override {
+    if (M.HasDxilModule()) {
+      DxilModule &DM = M.GetDxilModule();
+
+      bool IsLib = DM.GetShaderModel()->IsLib();
+      // Remove unused functions except entry and patch constant func.
+      // For library profile, only remove unused external functions.
+      Function *EntryFunc = DM.GetEntryFunction();
+      Function *PatchConstantFunc = DM.GetPatchConstantFunction();
+
+      std::vector<Function *> deadList;
+      for (iplist<Function>::iterator F : M.getFunctionList()) {
+        if (&(*F) == EntryFunc || &(*F) == PatchConstantFunc)
+          continue;
+        if (F->isDeclaration() || !IsLib) {
+          if (F->user_empty())
+            deadList.emplace_back(F);
+        }
+      }
+      bool bUpdated = deadList.size();
+      for (Function *F : deadList)
+        F->eraseFromParent();
+      return bUpdated;
+    }
+
+    return false;
+  }
+};
+}
+
+char DxilDeadFunctionElimination::ID = 0;
+
+ModulePass *llvm::createDxilDeadFunctionEliminationPass() {
+  return new DxilDeadFunctionElimination();
+}
+
+INITIALIZE_PASS(DxilDeadFunctionElimination, "dxil-dfe", "Remove all unused function except entry from DxilModule", false, false)
 
 ///////////////////////////////////////////////////////////////////////////////
 
@@ -205,11 +252,9 @@ public:
   bool runOnModule(Module &M) override {
     if (M.HasDxilModule()) {
       DxilModule &DM = M.GetDxilModule();
-      // Remove store undef output.
-      hlsl::OP *hlslOP = M.GetDxilModule().GetOP();
 
       bool IsLib = DM.GetShaderModel()->IsLib();
-      // Skip patch for lib.
+      // Skip validation patch for lib.
       if (!IsLib) {
         unsigned ValMajor = 0;
         unsigned ValMinor = 0;
@@ -219,148 +264,135 @@ public:
         }
       }
 
-      for (iplist<Function>::iterator F : M.getFunctionList()) {
-        if (!hlslOP->IsDxilOpFunc(F))
-          continue;
+      // Remove store undef output.
+      hlsl::OP *hlslOP = M.GetDxilModule().GetOP();
+      RemoveStoreUndefOutput(M, hlslOP);
 
-        // Check store output.
-        FunctionType *FT = F->getFunctionType();
-        // Num params not match.
-        if (FT->getNumParams() !=
-            (DXIL::OperandIndex::kStoreOutputValOpIdx + 1))
-          continue;
+      RemoveUnusedStaticGlobal(M);
 
-        Type *overloadTy =
-            FT->getParamType(DXIL::OperandIndex::kStoreOutputValOpIdx);
-        // overload illegal.
-        if (!hlslOP->IsOverloadLegal(DXIL::OpCode::StoreOutput, overloadTy))
-          continue;
-        Function *storeOutput =
-            hlslOP->GetOpFunc(DXIL::OpCode::StoreOutput, overloadTy);
-        // Not store output.
-        if (storeOutput != F)
-          continue;
+      // Clear inbound for GEP which has none-const index.
+      LegalizeShareMemoryGEPInbound(M);
 
-        for (auto it = F->user_begin(); it != F->user_end();) {
-          CallInst *CI = dyn_cast<CallInst>(*(it++));
-          if (!CI)
-            continue;
+      // Strip parameters of entry function.
+      StripEntryParameters(M, DM, IsLib);
 
-          Value *V =
-              CI->getArgOperand(DXIL::OperandIndex::kStoreOutputValOpIdx);
-          // Remove the store of undef.
-          if (isa<UndefValue>(V))
-            CI->eraseFromParent();
-        }
+      // Skip shader flag for library.
+      if (!IsLib) {
+        DM.CollectShaderFlags(); // Update flags to reflect any changes.
+                                 // Update Validator Version
+        DM.UpgradeToMinValidatorVersion();
       }
-      // Remove unused external functions.
-      // For none library profile, remove unused functions except entry and
-      // patchconstant function.
-      Function *EntryFunc = DM.GetEntryFunction();
-      Function *PatchConstantFunc = DM.GetPatchConstantFunction();
+      return true;
+    }
 
-      std::vector<Function *> deadList;
-      for (iplist<Function>::iterator F : M.getFunctionList()) {
-        if (&(*F) == EntryFunc || &(*F) == PatchConstantFunc)
+    return false;
+  }
+
+private:
+  void RemoveUnusedStaticGlobal(Module &M) {
+    // Remove unused internal global.
+    std::vector<GlobalVariable *> staticGVs;
+    for (GlobalVariable &GV : M.globals()) {
+      if (dxilutil::IsStaticGlobal(&GV) ||
+          dxilutil::IsSharedMemoryGlobal(&GV)) {
+        staticGVs.emplace_back(&GV);
+      }
+    }
+
+    for (GlobalVariable *GV : staticGVs) {
+      bool onlyStoreUse = true;
+      for (User *user : GV->users()) {
+        if (isa<StoreInst>(user))
           continue;
-        if (F->isDeclaration() || !IsLib) {
-          if (F->user_empty())
-            deadList.emplace_back(F);
+        if (isa<ConstantExpr>(user) && user->user_empty())
+          continue;
+        onlyStoreUse = false;
+        break;
+      }
+      if (onlyStoreUse) {
+        for (auto UserIt = GV->user_begin(); UserIt != GV->user_end();) {
+          Value *User = *(UserIt++);
+          if (Instruction *I = dyn_cast<Instruction>(User)) {
+            I->eraseFromParent();
+          } else {
+            ConstantExpr *CE = cast<ConstantExpr>(User);
+            CE->dropAllReferences();
+          }
         }
+        GV->eraseFromParent();
       }
+    }
+  }
 
-      for (Function *F : deadList)
-        F->eraseFromParent();
+  void RemoveStoreUndefOutput(Module &M, hlsl::OP *hlslOP) {
+    for (iplist<Function>::iterator F : M.getFunctionList()) {
+      if (!hlslOP->IsDxilOpFunc(F))
+        continue;
+      DXIL::OpCodeClass opClass;
+      DXASSERT(hlslOP->GetOpCodeClass(F, opClass), "else not a dxil op func");
+      if (opClass != DXIL::OpCodeClass::StoreOutput)
+        continue;
+
+      for (auto it = F->user_begin(); it != F->user_end();) {
+        CallInst *CI = dyn_cast<CallInst>(*(it++));
+        if (!CI)
+          continue;
 
-      // Remove unused internal global.
-      std::vector<GlobalVariable *> staticGVs;
-      for (GlobalVariable &GV : M.globals()) {
-        if (dxilutil::IsStaticGlobal(&GV) ||
-            dxilutil::IsSharedMemoryGlobal(&GV)) {
-          staticGVs.emplace_back(&GV);
-        }
+        Value *V = CI->getArgOperand(DXIL::OperandIndex::kStoreOutputValOpIdx);
+        // Remove the store of undef.
+        if (isa<UndefValue>(V))
+          CI->eraseFromParent();
       }
+    }
+  }
 
-      for (GlobalVariable *GV : staticGVs) {
-        bool onlyStoreUse = true;
-        for (User *user : GV->users()) {
-          if (isa<StoreInst>(user))
-            continue;
-          if (isa<ConstantExpr>(user) && user->user_empty())
-            continue;
-          onlyStoreUse = false;
-          break;
-        }
-        if (onlyStoreUse) {
-          for (auto UserIt = GV->user_begin(); UserIt != GV->user_end();) {
-            Value *User = *(UserIt++);
-            if (Instruction *I = dyn_cast<Instruction>(User)) {
-              I->eraseFromParent();
-            } else {
-              ConstantExpr *CE = cast<ConstantExpr>(User);
-              CE->dropAllReferences();
-            }
-          }
-          GV->eraseFromParent();
-        }
+  void LegalizeShareMemoryGEPInbound(Module &M) {
+    const DataLayout &DL = M.getDataLayout();
+    // Clear inbound for GEP which has none-const index.
+    for (GlobalVariable &GV : M.globals()) {
+      if (dxilutil::IsSharedMemoryGlobal(&GV)) {
+        CheckInBoundForTGSM(GV, DL);
       }
+    }
+  }
 
-      const DataLayout &DL = M.getDataLayout();
-      // Clear inbound for GEP which has none-const index.
-      for (GlobalVariable &GV : M.globals()) {
-        if (dxilutil::IsSharedMemoryGlobal(&GV)) {
-          CheckInBoundForTGSM(GV, DL);
-        }
+  void StripEntryParameters(Module &M, DxilModule &DM, bool IsLib) {
+    DenseMap<const Function *, DISubprogram *> FunctionDIs =
+        makeSubprogramMap(M);
+    // Strip parameters of entry function.
+    if (!IsLib) {
+      if (Function *PatchConstantFunc = DM.GetPatchConstantFunction()) {
+        PatchConstantFunc =
+            StripFunctionParameter(PatchConstantFunc, DM, FunctionDIs);
+        if (PatchConstantFunc)
+          DM.SetPatchConstantFunction(PatchConstantFunc);
       }
 
-      DenseMap<const Function *, DISubprogram *> FunctionDIs =
-          makeSubprogramMap(M);
-      // Strip parameters of entry function.
-      if (!IsLib) {
-        if (Function *PatchConstantFunc = DM.GetPatchConstantFunction()) {
-          PatchConstantFunc =
-              StripFunctionParameter(PatchConstantFunc, DM, FunctionDIs);
-          if (PatchConstantFunc)
-            DM.SetPatchConstantFunction(PatchConstantFunc);
-        }
-
-        if (Function *EntryFunc = DM.GetEntryFunction()) {
-          StringRef Name = DM.GetEntryFunctionName();
-          EntryFunc->setName(Name);
-          EntryFunc = StripFunctionParameter(EntryFunc, DM, FunctionDIs);
-          if (EntryFunc)
-            DM.SetEntryFunction(EntryFunc);
-        }
-      } else {
-        std::vector<Function *> entries;
-        for (iplist<Function>::iterator F : M.getFunctionList()) {
-          if (DM.HasDxilFunctionProps(F)) {
-            entries.emplace_back(F);
-          }
-        }
-        for (Function *entry : entries) {
-          DxilFunctionProps &props = DM.GetDxilFunctionProps(entry);
-          if (props.IsHS()) {
-            // Strip patch constant function first.
-            Function *patchConstFunc = StripFunctionParameter(
-                props.ShaderProps.HS.patchConstantFunc, DM, FunctionDIs);
-            props.ShaderProps.HS.patchConstantFunc = patchConstFunc;
-          }
-          StripFunctionParameter(entry, DM, FunctionDIs);
+      if (Function *EntryFunc = DM.GetEntryFunction()) {
+        StringRef Name = DM.GetEntryFunctionName();
+        EntryFunc->setName(Name);
+        EntryFunc = StripFunctionParameter(EntryFunc, DM, FunctionDIs);
+        if (EntryFunc)
+          DM.SetEntryFunction(EntryFunc);
+      }
+    } else {
+      std::vector<Function *> entries;
+      for (iplist<Function>::iterator F : M.getFunctionList()) {
+        if (DM.HasDxilFunctionProps(F)) {
+          entries.emplace_back(F);
         }
       }
-
-      // Skip shader flag for library.
-      if (!IsLib) {
-        DM.CollectShaderFlags(); // Update flags to reflect any changes.
-                                 // Update Validator Version
-        DM.UpgradeToMinValidatorVersion();
+      for (Function *entry : entries) {
+        DxilFunctionProps &props = DM.GetDxilFunctionProps(entry);
+        if (props.IsHS()) {
+          // Strip patch constant function first.
+          Function *patchConstFunc = StripFunctionParameter(
+              props.ShaderProps.HS.patchConstantFunc, DM, FunctionDIs);
+          props.ShaderProps.HS.patchConstantFunc = patchConstFunc;
+        }
+        StripFunctionParameter(entry, DM, FunctionDIs);
       }
-
-      return true;
     }
-
-    return false;
   }
 };
 }

+ 6 - 0
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -210,6 +210,10 @@ static void addHLSLPasses(bool HLSLHighLevel, bool NoOpt, hlsl::HLSLExtensionsCo
     return;
   }
 
+  if (!NoOpt) {
+    MPM.add(createDxilDeadFunctionEliminationPass());
+  }
+
   // Split struct and array of parameter.
   MPM.add(createSROA_Parameter_HLSL());
 
@@ -295,6 +299,7 @@ void PassManagerBuilder::populateModulePassManager(
       MPM.add(createDxilLegalizeSampleOffsetPass()); // HLSL Change
       MPM.add(createDxilFinalizeModulePass());      // HLSL Change
       MPM.add(createComputeViewIdStatePass());    // HLSL Change
+      MPM.add(createDxilDeadFunctionEliminationPass()); // HLSL Change
       MPM.add(createDxilEmitMetadataPass());      // HLSL Change
     }
     // HLSL Change Ends.
@@ -566,6 +571,7 @@ void PassManagerBuilder::populateModulePassManager(
       MPM.add(createDxilLegalizeSampleOffsetPass()); // HLSL Change
     MPM.add(createDxilFinalizeModulePass());
     MPM.add(createComputeViewIdStatePass()); // HLSL Change
+    MPM.add(createDxilDeadFunctionEliminationPass()); // HLSL Change
     MPM.add(createDxilEmitMetadataPass());
   }
   // HLSL Change Ends.

+ 1 - 0
utils/hct/hctdb.py

@@ -1285,6 +1285,7 @@ class db_dxil(object):
         add_pass('hlsl-dxilfinalize', 'DxilFinalizeModule', 'HLSL DXIL Finalize Module', [])
         add_pass('hlsl-dxilemit', 'DxilEmitMetadata', 'HLSL DXIL Metadata Emit', [])
         add_pass('hlsl-dxilload', 'DxilLoadMetadata', 'HLSL DXIL Metadata Load', [])
+        add_pass('dxil-dfe', 'DxilDeadFunctionElimination', 'Remove all unused function except entry from DxilModule', [])
         add_pass('hlsl-dxil-expand-trig', 'DxilExpandTrigIntrinsics', 'DXIL expand trig intrinsics', [])
         add_pass('hlsl-hca', 'HoistConstantArray', 'HLSL constant array hoisting', [])
         add_pass('hlsl-dxil-preserve-all-outputs', 'DxilPreserveAllOutputs', 'DXIL write to all outputs in signature', [])