ソースを参照

Replace static const global struct for library. (#436)

* Replace static const global struct for library.
Also support CK_HLSLMatrixSplat in ConstExprEmitter::VisitCastExpr.
Xiang Li 8 年 前
コミット
2a7cd92683

+ 13 - 0
tools/clang/lib/CodeGen/CGExprConstant.cpp

@@ -750,6 +750,19 @@ public:
       std::vector<llvm::Constant*> Elts(vecSize, C);
       return llvm::ConstantVector::get(Elts);
     }
+    case CK_HLSLMatrixSplat: {
+      llvm::StructType *ST =
+          cast<llvm::StructType>(CGM.getTypes().ConvertType(E->getType()));
+      unsigned row,col;
+      hlsl::GetHLSLMatRowColCount(E->getType(), row, col);
+
+      std::vector<llvm::Constant *> Cols(col, C);
+      llvm::Constant *Row = llvm::ConstantVector::get(Cols);
+      std::vector<llvm::Constant *> Rows(row, Row);
+      llvm::Constant *Mat = llvm::ConstantArray::get(
+          cast<llvm::ArrayType>(ST->getElementType(0)), Rows);
+      return llvm::ConstantStruct::get(ST, Mat);
+    }
     // HLSL Change Ends.
     }
     llvm_unreachable("Invalid CastKind");

+ 120 - 2
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -122,6 +122,11 @@ private:
   // Map to save entry functions.
   StringMap<Function *> entryFunctionMap;
 
+  // Static const global maps: init Expr -> GV, GV -> init list, GV -> ctor.
+  std::unordered_map<Expr *, GlobalVariable *> staticConstGlobalInitMap;
+  std::unordered_map<GlobalVariable *, std::vector<Constant *>>
+      staticConstGlobalInitListMap;
+  std::unordered_map<GlobalVariable *, Function *> staticConstGlobalCtorMap;
   // List for functions with clip plane.
   std::vector<Function *> clipPlaneFuncList;
   std::unordered_map<Value *, DebugLoc> debugInfoMap;
@@ -1851,8 +1856,17 @@ void CGMSHLSLRuntime::addResource(Decl *D) {
     if (VD->hasInit() && resClass != DXIL::ResourceClass::Invalid)
       return;
     // skip static global.
-    if (!VD->isExternallyVisible())
+    if (!VD->isExternallyVisible()) {
+      if (VD->hasInit() && VD->getType().isConstQualified()) {
+        Expr* InitExp = VD->getInit();
+        GlobalVariable *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(VD));
+        // Only save const static global of struct type.
+        if (GV->getType()->getElementType()->isStructTy()) {
+          staticConstGlobalInitMap[InitExp] = GV;
+        }
+      }
       return;
+    }
 
     if (D->hasAttr<HLSLGroupSharedAttr>()) {
       GlobalVariable *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(VD));
@@ -3852,6 +3866,78 @@ static void CloneShaderEntry(Function *ShaderF, StringRef EntryName,
   }
 }
 
+// Redirects one GEP user of a const static struct global to the storage its
+// fields were initialized from. For a case like:
+//   cbuffer A { float a; int b; }
+//   const static struct { float a; int b; } ST = { a, b };
+// users of ST.a and ST.b are replaced with users of a and b.
+// GEP is the user (its second index selects the field); InitList holds the
+// source pointer of each field by field number; Builder creates the new GEP.
+// Returns true if GEP's uses were replaced, false if GEP was left untouched.
+static bool ReplaceConstStaticGlobalUser(GEPOperator *GEP,
+                                         std::vector<Constant *> &InitList,
+                                         IRBuilder<> &Builder) {
+  if (GEP->getNumIndices() < 2) {
+    // Don't use sub element.
+    return false;
+  }
+  SmallVector<Value *, 4> idxList;
+  auto iter = GEP->idx_begin();
+  idxList.emplace_back(*(iter++));
+  ConstantInt *subIdx = dyn_cast<ConstantInt>(*(iter++));
+  // The asserts state the invariants; the explicit checks keep a build where
+  // DXASSERT is a no-op from dereferencing null or indexing past InitList.
+  DXASSERT(subIdx, "else dynamic indexing on struct field");
+  if (!subIdx)
+    return false;
+  unsigned subIdxImm = subIdx->getLimitedValue();
+  DXASSERT(subIdxImm < InitList.size(), "else struct index out of bound");
+  if (subIdxImm >= InitList.size())
+    return false;
+  Constant *subPtr = InitList[subIdxImm];
+  // Move every idx to idxList except idx for InitList.
+  while (iter != GEP->idx_end()) {
+    idxList.emplace_back(*(iter++));
+  }
+  Value *NewGEP = Builder.CreateGEP(subPtr, idxList);
+  GEP->replaceAllUsesWith(NewGEP);
+  return true;
+}
+
+// For each const static struct global with a recorded init list, redirects
+// the GEP users of the global to the sources its fields were initialized
+// from. When every user was redirected, the function recorded as the
+// global's ctor is emptied down to a single 'ret void'.
+static void ReplaceConstStaticGlobals(
+    std::unordered_map<GlobalVariable *, std::vector<Constant *>>
+        &staticConstGlobalInitListMap,
+    std::unordered_map<GlobalVariable *, Function *>
+        &staticConstGlobalCtorMap) {
+  for (auto &iter : staticConstGlobalInitListMap) {
+    GlobalVariable *GV = iter.first;
+    std::vector<Constant *> &InitList = iter.second;
+    LLVMContext &Ctx = GV->getContext();
+    // Stays true only while every user of GV could be redirected.
+    bool bPass = true;
+    for (User *U : GV->users()) {
+      IRBuilder<> Builder(Ctx);
+      if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U))
+        Builder.SetInsertPoint(GEPInst);
+      GEPOperator *GEP = dyn_cast<GEPOperator>(U);
+      DXASSERT(GEP, "invalid user of const static global");
+      // A user that is not a GEP still reads GV, so the ctor must be kept.
+      bPass &= GEP && ReplaceConstStaticGlobalUser(GEP, InitList, Builder);
+    }
+    // Clear the Ctor which is useless now; skip a GV without a recorded one.
+    auto found = staticConstGlobalCtorMap.find(GV);
+    if (bPass && found != staticConstGlobalCtorMap.end() && found->second) {
+      Function *Ctor = found->second;
+      Ctor->getBasicBlockList().clear();
+      IRBuilder<>(BasicBlock::Create(Ctx, "", Ctor)).CreateRetVoid();
+    }
+  }
+}
+
 void CGMSHLSLRuntime::FinishCodeGen() {
   // Library don't have entry.
   if (!m_bIsLib) {
@@ -3869,6 +3955,9 @@ void CGMSHLSLRuntime::FinishCodeGen() {
     }
   }
 
+  ReplaceConstStaticGlobals(staticConstGlobalInitListMap,
+                            staticConstGlobalCtorMap);
+
   // Create copy for clip plane.
   for (Function *F : clipPlaneFuncList) {
     DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(F);
@@ -4587,7 +4676,36 @@ static bool ExpTypeMatch(Expr *E, QualType Ty, ASTContext &Ctx, CodeGenTypes &Ty
 bool CGMSHLSLRuntime::IsTrivalInitListExpr(CodeGenFunction &CGF,
                                            InitListExpr *E) {
   QualType Ty = E->getType();
-  return ExpTypeMatch(E, Ty, CGF.getContext(), CGF.getTypes());
+  bool result = ExpTypeMatch(E, Ty, CGF.getContext(), CGF.getTypes());
+  if (result) {
+    auto iter = staticConstGlobalInitMap.find(E); // Is E a tracked global's init?
+    if (iter != staticConstGlobalInitMap.end()) {
+      GlobalVariable * GV = iter->second;
+      auto &InitConstants = staticConstGlobalInitListMap[GV];
+      // Record the address each init expression reads from, in init order.
+      for (unsigned i=0;i<E->getNumInits();i++) {
+        Expr *Expr = E->getInit(i);
+        LValue LV = CGF.EmitLValue(Expr);
+        if (LV.isSimple()) {
+          Constant *SrcPtr = dyn_cast<Constant>(LV.getAddress());
+          if (SrcPtr && !isa<UndefValue>(SrcPtr)) {
+            InitConstants.emplace_back(SrcPtr);
+            continue;
+          }
+        }
+
+        // Only a simple lvalue whose address is a non-undef Constant is supported.
+        // Any other init drops the whole record so the global takes the normal path.
+        InitConstants.clear();
+        break;
+      }
+      if (InitConstants.empty())
+        staticConstGlobalInitListMap.erase(GV);
+      else
+        staticConstGlobalCtorMap[GV] = CGF.CurFn; // Function being emitted now.
+    }
+  }
+  return result;
 }
 
 Value *CGMSHLSLRuntime::EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E,

+ 22 - 0
tools/clang/test/CodeGenHLSL/static_const_global.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -T lib_6_1 %s | FileCheck %s
+
+// Make sure ST is removed
+// CHECK-NOT: @ST
+
+cbuffer A {
+  float a;
+  int b;
+}
+
+const static struct {
+  float a;
+  int b;
+}  ST = { a, b };
+
+float4 test() {
+  return ST.a + ST.b;
+}
+
+float test2() {
+  return ST.a - ST.b;
+}

+ 19 - 0
tools/clang/test/CodeGenHLSL/static_const_global2.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -T ps_6_1 -E main -fcgl %s | FileCheck %s
+
+// Make sure ST is referenced only once, by its declaration.
+// CHECK: @ST
+// CHECK-NOT: @ST
+
+cbuffer A {
+  float a;
+  int b;
+}
+
+const static struct {
+  float a;
+  int b;
+}  ST = { a, b };
+
+float4 main() : SV_TARGET  {
+  return ST.a + ST.b;
+}

+ 10 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -674,6 +674,8 @@ public:
   TEST_METHOD(CodeGenSrv_Ms_Load2)
   TEST_METHOD(CodeGenSrv_Typed_Load1)
   TEST_METHOD(CodeGenSrv_Typed_Load2)
+  TEST_METHOD(CodeGenStaticConstGlobal)
+  TEST_METHOD(CodeGenStaticConstGlobal2)
   TEST_METHOD(CodeGenStaticGlobals)
   TEST_METHOD(CodeGenStaticGlobals2)
   TEST_METHOD(CodeGenStaticGlobals3)
@@ -3696,6 +3698,14 @@ TEST_F(CompilerTest, CodeGenSrv_Typed_Load2) {
   CodeGenTest(L"..\\CodeGenHLSL\\srv_typed_load2.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenStaticConstGlobal) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\static_const_global.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenStaticConstGlobal2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\static_const_global2.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenStaticGlobals) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\staticGlobals.hlsl");
 }