Browse Source

Save imm initializer when possible. (#145)

* Save imm initializer when possible.
* Remove UpdateHLSLIncompleteArrayType, whose work is already done in SemaInit.
* Take care of parent records in AddMissingCastOpsInInitList and HLSLExternalSource::GetNumBasicElements.
* Support nested init lists in CaculateInitListArraySizeForHLSL.
Xiang Li 8 years ago
parent
commit
094dfa1746

+ 46 - 9
lib/HLSL/HLMatrixLowerPass.cpp

@@ -1416,7 +1416,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
     Value *matGlobal, ArrayRef<Value *> vecGlobals,
     CallInst *matLdStInst) {
   // No dynamic indexing on matrix, flatten matrix to scalars.
-
+  // Internal global matrix use row major follow the initializer.
   Type *matType = matGlobal->getType()->getPointerElementType();
   unsigned col, row;
   HLMatrixLower::GetMatrixInfo(matType, col, row);
@@ -2211,6 +2211,22 @@ static bool OnlyUsedByMatrixLdSt(Value *V) {
   return onlyLdSt;
 }
 
+// Recursively rebuilds the constant MA with the lowered array type ResultTy.
+// While MA is array-typed, every element is lowered with the element type of
+// ResultTy. Once MA is no longer an array it is treated as the struct that
+// wraps float[row][col], and aggregate element 0 is returned as is.
+static Constant *LowerMatrixArrayConst(Constant *MA, ArrayType *ResultTy) {
+  ArrayType *SrcArrayTy = dyn_cast<ArrayType>(MA->getType());
+  if (!SrcArrayTy) {
+    // Leaf: get float[row][col] from the struct.
+    return MA->getAggregateElement((unsigned)0);
+  }
+
+  ArrayType *LoweredEltTy = cast<ArrayType>(ResultTy->getElementType());
+  std::vector<Constant *> LoweredElts;
+  for (unsigned Idx = 0; Idx < SrcArrayTy->getNumElements(); ++Idx)
+    LoweredElts.emplace_back(
+        LowerMatrixArrayConst(MA->getAggregateElement(Idx), LoweredEltTy));
+  return ConstantArray::get(ResultTy, LoweredElts);
+}
+
 void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
   // Lower to array of vector array like float[row][col].
   // DynamicIndexingVectorToArray will change it to scalar array.
@@ -2230,10 +2246,11 @@ void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
     Ty = ArrayType::get(Ty, *arraySize);
 
   Type *VecArrayTy = Ty;
-
-  // Matrix will use store to initialize.
-  // So set init val to undef.
-  Constant *InitVal = UndefValue::get(VecArrayTy);
+  Constant *OldInitVal = GV->getInitializer();
+  Constant *InitVal =
+      isa<UndefValue>(OldInitVal)
+          ? UndefValue::get(VecArrayTy)
+          : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
 
   bool isConst = GV->isConstant();
   GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
@@ -2285,6 +2302,24 @@ void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
   GV->eraseFromParent();
 }
 
+// Appends the scalar elements of the matrix constant M to Elts, walking rows
+// in the outer loop and columns in the inner loop. An undef matrix
+// contributes col*row undef scalars of the matrix element type.
+static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
+  unsigned NumRows, NumCols;
+  Type *ScalarTy =
+      HLMatrixLower::GetMatrixInfo(M->getType(), NumCols, NumRows);
+
+  if (isa<UndefValue>(M)) {
+    Constant *UndefElt = UndefValue::get(ScalarTy);
+    for (unsigned Left = NumCols * NumRows; Left != 0; --Left)
+      Elts.emplace_back(UndefElt);
+    return;
+  }
+
+  // Element 0 of the matrix constant holds the per-row aggregates.
+  Constant *RowData = M->getAggregateElement((unsigned)0);
+  for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx) {
+    Constant *Row = RowData->getAggregateElement(RowIdx);
+    for (unsigned ColIdx = 0; ColIdx < NumCols; ++ColIdx)
+      Elts.emplace_back(Row->getAggregateElement(ColIdx));
+  }
+}
+
 void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
     runOnGlobalMatrixArray(GV);
@@ -2303,13 +2338,13 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   Module *M = GV->getParent();
   const DataLayout &DL = M->getDataLayout();
 
+  std::vector<Constant *> Elts;
+  FlattenMatConst(GV->getInitializer(), Elts);
+
   if (onlyLdSt) {
     Type *EltTy = vecTy->getVectorElementType();
     unsigned vecSize = vecTy->getVectorNumElements();
     std::vector<Value *> vecGlobals(vecSize);
-    // Matrix will use store to initialize.
-    // So set init val to undef.
-    Constant *InitVal = UndefValue::get(EltTy);
 
     GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
     unsigned AddressSpace = GV->getType()->getAddressSpace();
@@ -2318,6 +2353,7 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
     unsigned size = DL.getTypeAllocSizeInBits(EltTy);
     unsigned align = DL.getPrefTypeAlignment(EltTy);
     for (int i = 0, e = vecSize; i != e; ++i) {
+      Constant *InitVal = Elts[i];
       GlobalVariable *EltGV = new llvm::GlobalVariable(
           *M, EltTy, /*IsConstant*/ isConst, linkage,
           /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
@@ -2344,9 +2380,10 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   else {
     // lower to array of scalar here.
     ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(), vecTy->getVectorNumElements());
+    Constant *InitVal = ConstantArray::get(AT, Elts);
     GlobalVariable *arrayMat = new llvm::GlobalVariable(
       *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
-      /*InitVal*/ UndefValue::get(AT), GV->getName());
+      /*InitVal*/ InitVal, GV->getName());
     // Add debug info.
     if (m_HasDbgInfo) {
       DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();

+ 80 - 17
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3211,6 +3211,44 @@ bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
   return true;
 }
 
+// Returns the initializer for the idx-th scalar-replaced piece of a global
+// whose type is Ty and whose initializer is Init. EltTy is the type of that
+// piece.
+//  - Undef initializer: undef of EltTy.
+//  - Struct or vector: aggregate element idx of Init.
+//  - Array: a ConstantArray of EltTy (which must itself be an array type)
+//    built by taking element idx out of every array element; nested arrays
+//    are handled by recursing one array level at a time.
+static Constant *GetEltInit(Type *Ty, Constant *Init, unsigned idx,
+                            Type *EltTy) {
+  if (isa<UndefValue>(Init))
+    return UndefValue::get(EltTy);
+
+  if (isa<StructType>(Ty) || isa<VectorType>(Ty))
+    return Init->getAggregateElement(idx);
+
+  ArrayType *AT = cast<ArrayType>(Ty);
+  ArrayType *EltArrayTy = cast<ArrayType>(EltTy);
+  Type *InnerTy = AT->getElementType();
+  const bool IsNestedArray = InnerTy->isArrayTy();
+  // Only meaningful (and only cast) when another array level remains.
+  ArrayType *NestEltArrayTy =
+      IsNestedArray ? cast<ArrayType>(EltArrayTy->getElementType()) : nullptr;
+
+  std::vector<Constant *> Elts;
+  for (unsigned i = 0; i < AT->getNumElements(); i++) {
+    // Get Array[i]
+    Constant *InitArrayElt = Init->getAggregateElement(i);
+    // Get Array[i].idx
+    InitArrayElt = IsNestedArray
+                       ? GetEltInit(InnerTy, InitArrayElt, idx, NestEltArrayTy)
+                       : InitArrayElt->getAggregateElement(idx);
+    Elts.emplace_back(InitArrayElt);
+  }
+  return ConstantArray::get(EltArrayTy, Elts);
+}
+
 /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
 /// Then do SROA on V.
 bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &Elts,
@@ -3252,7 +3290,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &
     Elts.reserve(numTypes);
     //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
     for (int i = 0, e = numTypes; i != e; ++i) {
-      Constant *EltInit = cast<Constant>(Builder.CreateExtractValue(Init, i));
+      Constant *EltInit = GetEltInit(Ty, Init, i, ST->getElementType(i));
       GlobalVariable *EltGV = new llvm::GlobalVariable(
           *M, ST->getContainedType(i), /*IsConstant*/ isConst, linkage,
           /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
@@ -3271,7 +3309,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &
     Type *EltTy = VT->getElementType();
     //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
     for (int i = 0, e = numElts; i != e; ++i) {
-      Constant *EltInit = cast<Constant>(Builder.CreateExtractElement(Init, i));
+      Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
       GlobalVariable *EltGV = new llvm::GlobalVariable(
           *M, EltTy, /*IsConstant*/ isConst, linkage,
           /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
@@ -3312,8 +3350,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &
       for (int i = 0, e = numTypes; i != e; ++i) {
         Type *EltTy =
             CreateNestArrayTy(ElST->getContainedType(i), nestArrayTys);
-        // Don't need InitVal, struct type will use store to init.
-        Constant *EltInit = UndefValue::get(EltTy);
+        Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
         GlobalVariable *EltGV = new llvm::GlobalVariable(
             *M, EltTy, /*IsConstant*/ isConst, linkage,
             /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
@@ -3339,8 +3376,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &
           CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
 
       for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
-        // Don't need InitVal, struct type will use store to init.
-        Constant *EltInit = UndefValue::get(scalarArrayTy);
+        Constant *EltInit = GetEltInit(Ty, Init, i, scalarArrayTy);
         GlobalVariable *EltGV = new llvm::GlobalVariable(
             *M, scalarArrayTy, /*IsConstant*/ isConst, linkage,
             /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
@@ -5193,6 +5229,26 @@ bool DynamicIndexingVectorToArray::runOnFunction(Function &F) {
   return size > 0;
 }
 
+// Rebuilds the constant C, whose type VecTy is a vector or a (possibly nested)
+// array, as a constant of type ArrayTy. A vector becomes an array of its
+// elements; an array is converted element by element using the element type
+// of ArrayTy.
+static Constant *VectorConstToArray(Type *VecTy, Constant *C, ArrayType *ArrayTy) {
+  SmallVector<Constant *, 4> Converted;
+  if (VecTy->isVectorTy()) {
+    for (unsigned Lane = 0; Lane < VecTy->getVectorNumElements(); ++Lane)
+      Converted.emplace_back(C->getAggregateElement(Lane));
+  } else {
+    ArrayType *SrcArrayTy = cast<ArrayType>(VecTy);
+    Type *SrcEltTy = SrcArrayTy->getElementType();
+    ArrayType *DstEltTy = cast<ArrayType>(ArrayTy->getElementType());
+    for (unsigned Idx = 0; Idx < SrcArrayTy->getNumElements(); ++Idx)
+      Converted.emplace_back(
+          VectorConstToArray(SrcEltTy, C->getAggregateElement(Idx), DstEltTy));
+  }
+  return ConstantArray::get(ArrayTy, Converted);
+}
+
 void DynamicIndexingVectorToArray::runOnInternalGlobal(GlobalVariable *GV,
                                                        HLModule *HLM) {
   Type *Ty = GV->getType()->getPointerElementType();
@@ -5224,15 +5280,7 @@ void DynamicIndexingVectorToArray::runOnInternalGlobal(GlobalVariable *GV,
       InitVal = ConstantAggregateZero::get(AT);
     else if (!isa<UndefValue>(vecInitVal)) {
       // build arrayInitVal.
-      // Only vector initializer could reach here.
-      // Complex case will use store to init.
-      DXASSERT_NOMSG(vecInitVal->getType()->isVectorTy());
-      ConstantDataVector *CDV = cast<ConstantDataVector>(vecInitVal);
-      unsigned vecSize = CDV->getType()->getVectorNumElements();
-      std::vector<Constant *> vals;
-      for (unsigned i = 0; i < vecSize; i++)
-        vals.emplace_back(CDV->getAggregateElement(i));
-      InitVal = ConstantArray::get(AT, vals);
+      InitVal = VectorConstToArray(vecInitVal->getType(), vecInitVal, AT);
     }
   }
 
@@ -5421,6 +5469,18 @@ void MultiDimArrayToOneDimArray::flattenAlloca(AllocaInst *AI) {
   AI->eraseFromParent();
 }
 
+// Appends the non-array leaves of the constant V to Elts in element order,
+// descending through every array level.
+static void FlattenMultiDimConstArray(Constant *V,
+                                      std::vector<Constant *> &Elts) {
+  ArrayType *ArrTy = dyn_cast<ArrayType>(V->getType());
+  if (!ArrTy) {
+    Elts.emplace_back(V);
+    return;
+  }
+  for (unsigned Idx = 0; Idx < ArrTy->getNumElements(); ++Idx)
+    FlattenMultiDimConstArray(V->getAggregateElement(Idx), Elts);
+}
+
 void MultiDimArrayToOneDimArray::flattenGlobal(GlobalVariable *GV, DxilModule *DM) {
   Type *Ty = GV->getType()->getElementType();
 
@@ -5443,8 +5503,11 @@ void MultiDimArrayToOneDimArray::flattenGlobal(GlobalVariable *GV, DxilModule *D
       InitVal = ConstantAggregateZero::get(AT);
     else if (isa<UndefValue>(InitVal))
       InitVal = UndefValue::get(AT);
-    else
-      DXASSERT(0, "invalid initializer");
+    else {
+      std::vector<Constant *> Elts;
+      FlattenMultiDimConstArray(InitVal, Elts);
+      InitVal = ConstantArray::get(AT, Elts);
+    }
   } else {
     InitVal = UndefValue::get(AT);
   }

+ 0 - 12
tools/clang/lib/CodeGen/CGDecl.cpp

@@ -343,12 +343,6 @@ void CodeGenFunction::EmitStaticVarDecl(const VarDecl &D,
   llvm::Value *&DMEntry = LocalDeclMap[&D];
   assert(!DMEntry && "Decl already exists in localdeclmap!");
 
-  // HLSL Change Begins.
-  if (D.getType()->isIncompleteArrayType() && getLangOpts().HLSL) {
-    CGM.getHLSLRuntime().UpdateHLSLIncompleteArrayType(const_cast<VarDecl&>(D));
-  }
-  // HLSL Change Ends.
-
   // Check to see if we already have a global variable for this
   // declaration.  This can happen when double-emitting function
   // bodies, e.g. with complete and base constructors.
@@ -911,12 +905,6 @@ CodeGenFunction::EmitAutoVarAlloca(const VarDecl &D) {
   if (Ty->isVariablyModifiedType())
     EmitVariablyModifiedType(Ty);
 
-  // HLSL Change Begins.
-  if (Ty->isIncompleteArrayType() && getLangOpts().HLSL) {
-    Ty = CGM.getHLSLRuntime().UpdateHLSLIncompleteArrayType(const_cast<VarDecl&>(D));
-  }
-  // HLSL Change Ends.
-
   llvm::Value *DeclPtr;
   if (Ty->isConstantSizeType()) {
     bool NRVO = getLangOpts().ElideConstructors &&

+ 0 - 6
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -1931,12 +1931,6 @@ static LValue EmitGlobalVarDeclLValue(CodeGenFunction &CGF,
       CGF.CGM.getCXXABI().usesThreadWrapperFunction())
     return CGF.CGM.getCXXABI().EmitThreadLocalVarDeclLValue(CGF, VD, T);
 
-  // HLSL Change Begins.
-  if (VD->getType()->isIncompleteArrayType() && CGF.getLangOpts().HLSL) {
-    T = CGF.CGM.getHLSLRuntime().UpdateHLSLIncompleteArrayType(const_cast<VarDecl&>(*VD));
-  }
-  // HLSL Change Ends.
-
   llvm::Value *V = CGF.CGM.GetAddrOfGlobalVar(VD);
   llvm::Type *RealVarTy = CGF.getTypes().ConvertTypeForMem(VD->getType());
   V = EmitBitCastOfLValueToProperType(CGF, V, RealVarTy);

+ 10 - 9
tools/clang/lib/CodeGen/CGExprConstant.cpp

@@ -25,6 +25,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
+#include "CGHLSLRuntime.h"   // HLSL Change
 using namespace clang;
 using namespace CodeGen;
 
@@ -744,8 +745,11 @@ public:
     // HLSL Change Begins.
     case CK_FlatConversion:
       return nullptr;
-    case CK_HLSLVectorSplat:
-      return nullptr;
+    case CK_HLSLVectorSplat: {
+      unsigned vecSize = hlsl::GetHLSLVecSize(E->getType());
+      std::vector<llvm::Constant*> Elts(vecSize, C);
+      return llvm::ConstantVector::get(Elts);
+    }
     // HLSL Change Ends.
     }
     llvm_unreachable("Invalid CastKind");
@@ -833,15 +837,12 @@ public:
   }
 
   llvm::Constant *VisitInitListExpr(InitListExpr *ILE) {
-    if (ILE->getType()->isArrayType())
-      return EmitArrayInitialization(ILE);
-
     // HLSL Change Begins.
-    if (hlsl::IsHLSLVecType(ILE->getType()))
-      return CGM.EmitConstantExpr(ILE, ILE->getType(), CGF);
-    if (hlsl::IsHLSLMatType(ILE->getType()))
-      return nullptr;
+    if (CGM.getLangOpts().HLSL)
+      return CGM.getHLSLRuntime().EmitHLSLConstInitListExpr(CGM, ILE);
     // HLSL Change Ends.
+    if (ILE->getType()->isArrayType())
+      return EmitArrayInitialization(ILE);
 
     if (ILE->getType()->isRecordType())
       return EmitRecordInitialization(ILE);

+ 194 - 41
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -197,7 +197,7 @@ public:
   void addResource(Decl *D) override;
   void FinishCodeGen() override;
   Value *EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E, Value *DestPtr) override;
-  QualType UpdateHLSLIncompleteArrayType(VarDecl &D) override;
+  Constant *EmitHLSLConstInitListExpr(CodeGenModule &CGM, InitListExpr *E) override;
 
   RValue EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF, const FunctionDecl *FD,
                                  const CallExpr *E,
@@ -4233,6 +4233,19 @@ static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVect
       if (!RT)
         RT = Ty->getAs<RecordType>();
       RecordDecl *RD = RT->getDecl();
+      // Take care base.
+      if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases()) {
+          for (const auto &I : CXXRD->bases()) {
+            const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
+                I.getType()->castAs<RecordType>()->getDecl());
+            if (BaseDecl->field_empty())
+              continue;
+            QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
+            AddMissingCastOpsInInitList(elts, eltTys, idx, parentTy, CGF);
+          }
+        }
+      }
       for (FieldDecl *field : RD->fields())
         AddMissingCastOpsInInitList(elts, eltTys, idx, field->getType(), CGF);
     }
@@ -4333,46 +4346,6 @@ void CGMSHLSLRuntime::ScanInitList(CodeGenFunction &CGF, InitListExpr *E,
   }
 }
 
-unsigned CGMSHLSLRuntime::ScanInitList(InitListExpr *E) {
-  unsigned NumInitElements = E->getNumInits();
-  unsigned size = 0;
-  for (unsigned i = 0; i != NumInitElements; ++i) {
-    Expr *init = E->getInit(i);
-    QualType iType = init->getType();
-    if (InitListExpr *initList = dyn_cast<InitListExpr>(init)) {
-      size += ScanInitList(initList);
-    } else if (CodeGenFunction::hasScalarEvaluationKind(iType)) {
-      size += GetElementCount(iType);
-    } else {
-      DXASSERT(0, "not support yet");
-    }
-
-  }
-  return size;
-}
-
-QualType CGMSHLSLRuntime::UpdateHLSLIncompleteArrayType(VarDecl &D) {
-  if (!D.hasInit())
-    return D.getType();
-
-  InitListExpr *E = dyn_cast<InitListExpr>(D.getInit());
-  if (!E)
-    return D.getType();
-
-  unsigned arrayEltCount = ScanInitList(E);
-
-  QualType ResultTy = E->getType();
-
-  QualType EltTy = QualType(ResultTy->getArrayElementTypeNoTypeQual(), 0);
-  unsigned eltCount = GetElementCount(EltTy);
-  llvm::APInt ArySize(32, arrayEltCount / eltCount);
-  QualType ArrayTy = CGM.getContext().getConstantArrayType(
-      EltTy, ArySize, clang::ArrayType::Normal, 0);
-  D.setType(ArrayTy);
-  E->setType(ArrayTy);
-  return ArrayTy;
-}
-
 Value *CGMSHLSLRuntime::EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E,
       // The destPtr when emiting aggregate init, for normal case, it will be null.
       Value *DestPtr) {
@@ -4412,6 +4385,186 @@ Value *CGMSHLSLRuntime::EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr
   }
 }
 
+// Appends the leaves of the constant C to EltValList. Vector, array and
+// struct constants are expanded element by element, recursively; any other
+// constant is appended as is.
+static void FlatConstToList(Constant *C,
+                            SmallVector<Constant *, 4> &EltValList) {
+  llvm::Type *Ty = C->getType();
+
+  uint64_t NumElts = 0;
+  if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
+    NumElts = VT->getNumElements();
+  } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
+    NumElts = AT->getNumElements();
+  } else if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
+    NumElts = ST->getNumElements();
+  } else {
+    // Not an aggregate: this constant is a leaf.
+    EltValList.emplace_back(C);
+    return;
+  }
+
+  for (unsigned i = 0; i < NumElts; i++)
+    FlatConstToList(C->getAggregateElement(i), EltValList);
+}
+
+// Walks the init list E (recursing into nested init lists) and appends the
+// flattened constant value of every initializer to EltValList.
+// Returns false as soon as an initializer cannot be handled: a DeclRefExpr
+// that is not a VarDecl or has no constant initializer, a matrix-typed
+// expression, an expression without scalar evaluation kind, or an expression
+// for which no constant can be emitted.
+static bool ScanConstInitList(CodeGenModule &CGM, InitListExpr *E,
+                              SmallVector<Constant *, 4> &EltValList) {
+  const unsigned NumInits = E->getNumInits();
+  for (unsigned Idx = 0; Idx != NumInits; ++Idx) {
+    Expr *Item = E->getInit(Idx);
+    QualType ItemTy = Item->getType();
+
+    if (InitListExpr *Nested = dyn_cast<InitListExpr>(Item)) {
+      if (!ScanConstInitList(CGM, Nested, EltValList))
+        return false;
+      continue;
+    }
+
+    if (DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(Item)) {
+      VarDecl *Var = dyn_cast<VarDecl>(Ref->getDecl());
+      if (!Var)
+        return false;
+      Constant *VarInit = CGM.EmitConstantInit(*Var);
+      if (!VarInit)
+        return false;
+      FlatConstToList(VarInit, EltValList);
+      continue;
+    }
+
+    // Matrix expressions and non-scalar-evaluated expressions are rejected.
+    if (hlsl::IsHLSLMatType(ItemTy) ||
+        !CodeGenFunction::hasScalarEvaluationKind(ItemTy))
+      return false;
+
+    Constant *ItemVal = CGM.EmitConstantExpr(Item, ItemTy);
+    if (!ItemVal)
+      return false;
+    FlatConstToList(ItemVal, EltValList);
+  }
+  return true;
+}
+
+static Constant *BuildConstInitializer(QualType Type, unsigned &offset,
+                                       SmallVector<Constant *, 4> &EltValList,
+                                       CodeGenTypes &Types);
+
+// Builds a constant vector of type VT by building one element initializer of
+// the HLSL vector element type per lane; `offset` is passed through to
+// BuildConstInitializer, which advances it over EltValList.
+static Constant *BuildConstVector(llvm::VectorType *VT, unsigned &offset,
+                                  SmallVector<Constant *, 4> &EltValList,
+                                  QualType Type, CodeGenTypes &Types) {
+  QualType LaneQTy = hlsl::GetHLSLVecElementType(Type);
+  SmallVector<Constant *, 4> Lanes;
+  for (unsigned Lane = 0; Lane < VT->getNumElements(); ++Lane) {
+    Constant *LaneVal =
+        BuildConstInitializer(LaneQTy, offset, EltValList, Types);
+    Lanes.emplace_back(LaneVal);
+  }
+  return llvm::ConstantVector::get(Lanes);
+}
+
+// Builds the constant for a matrix of LLVM type Ty (a struct whose element 0
+// is an array). Elements are taken in row-major order: one constant vector of
+// `col` elements is built per row, the rows are packed into the array, and
+// the array is wrapped in the matrix struct.
+static Constant *BuildConstMatrix(llvm::Type *Ty, unsigned &offset,
+                                  SmallVector<Constant *, 4> &EltValList,
+                                  QualType Type, CodeGenTypes &Types) {
+  unsigned NumCols, NumRows;
+  HLMatrixLower::GetMatrixInfo(Ty, NumCols, NumRows);
+  QualType ScalarQTy = hlsl::GetHLSLMatElementType(Type);
+  llvm::ArrayType *RowArrayTy =
+      cast<llvm::ArrayType>(Ty->getStructElementType(0));
+
+  // Matrix initializer is row major.
+  // The type is vector<element, col>[row].
+  SmallVector<Constant *, 4> RowVecs;
+  for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx) {
+    SmallVector<Constant *, 4> RowElts;
+    for (unsigned ColIdx = 0; ColIdx < NumCols; ++ColIdx)
+      RowElts.emplace_back(
+          BuildConstInitializer(ScalarQTy, offset, EltValList, Types));
+    RowVecs.emplace_back(llvm::ConstantVector::get(RowElts));
+  }
+
+  Constant *Payload = llvm::ConstantArray::get(RowArrayTy, RowVecs);
+  return llvm::ConstantStruct::get(cast<llvm::StructType>(Ty), Payload);
+}
+
+// Builds a constant array of type AT by building one initializer of the
+// clang array element type per array element.
+static Constant *BuildConstArray(llvm::ArrayType *AT, unsigned &offset,
+                                 SmallVector<Constant *, 4> &EltValList,
+                                 QualType Type, CodeGenTypes &Types) {
+  QualType ItemQTy = QualType(Type->getArrayElementTypeNoTypeQual(), 0);
+  SmallVector<Constant *, 4> Items;
+  for (unsigned Idx = 0; Idx < AT->getNumElements(); ++Idx) {
+    Constant *Item = BuildConstInitializer(ItemQTy, offset, EltValList, Types);
+    Items.emplace_back(Item);
+  }
+  return llvm::ConstantArray::get(AT, Items);
+}
+
+// Builds a constant struct of type ST for the clang record type `Type`.
+// Members are emitted in this order: first one whole constant per base class
+// that has fields (bases without fields are skipped), then one constant per
+// field of the record itself.
+static Constant *BuildConstStruct(llvm::StructType *ST, unsigned &offset,
+                                  SmallVector<Constant *, 4> &EltValList,
+                                  QualType Type, CodeGenTypes &Types) {
+  const RecordType *RT = Type->getAsStructureType();
+  if (!RT)
+    RT = Type->getAs<RecordType>();
+  const RecordDecl *RD = RT->getDecl();
+
+  SmallVector<Constant *, 4> Members;
+
+  if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+    // Add base as field; iterating an empty base list adds nothing.
+    for (const auto &Base : CXXRD->bases()) {
+      const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
+          Base.getType()->castAs<RecordType>()->getDecl());
+      // Skip empty struct.
+      if (BaseDecl->field_empty())
+        continue;
+      // Add base as a whole constant. Not as element.
+      Members.emplace_back(
+          BuildConstInitializer(Base.getType(), offset, EltValList, Types));
+    }
+  }
+
+  for (const FieldDecl *Field : RD->fields())
+    Members.emplace_back(
+        BuildConstInitializer(Field->getType(), offset, EltValList, Types));
+
+  return llvm::ConstantStruct::get(ST, Members);
+}
+
+// Builds the constant initializer for the clang type `Type`, dispatching on
+// its converted LLVM type (vector, array, matrix, struct). For anything else
+// the next value is taken from EltValList at `offset`, which is then advanced.
+// That value is returned unchanged when its type already matches, or when the
+// target is i1 and the value is i32; otherwise it is cast to the target type
+// with the cast op chosen by HLModule::FindCastOp.
+static Constant *BuildConstInitializer(QualType Type, unsigned &offset,
+                                       SmallVector<Constant *, 4> &EltValList,
+                                       CodeGenTypes &Types) {
+  llvm::Type *Ty = Types.ConvertType(Type);
+
+  if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty))
+    return BuildConstVector(VT, offset, EltValList, Type, Types);
+  if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty))
+    return BuildConstArray(AT, offset, EltValList, Type, Types);
+  if (HLMatrixLower::IsMatrixType(Ty))
+    return BuildConstMatrix(Ty, offset, EltValList, Type, Types);
+  if (StructType *ST = dyn_cast<llvm::StructType>(Ty))
+    return BuildConstStruct(ST, offset, EltValList, Type, Types);
+
+  // Scalar basic types.
+  Constant *Val = EltValList[offset++];
+  if (Val->getType() == Ty)
+    return Val;
+
+  IRBuilder<> Builder(Ty->getContext());
+  // Don't cast int to bool. bool only for scalar.
+  if (Ty == Builder.getInt1Ty() && Val->getType() == Builder.getInt32Ty())
+    return Val;
+
+  Instruction::CastOps castOp =
+      static_cast<Instruction::CastOps>(HLModule::FindCastOp(
+          IsUnsigned(Type), IsUnsigned(Type), Val->getType(), Ty));
+  return cast<Constant>(Builder.CreateCast(castOp, Val, Ty));
+}
+
+// Tries to emit the init list E as a single constant. The list is first
+// flattened into constant values; if any initializer cannot be flattened,
+// nullptr is returned. Otherwise the constant is rebuilt for E's type from
+// the flattened values, starting at offset 0.
+Constant *CGMSHLSLRuntime::EmitHLSLConstInitListExpr(CodeGenModule &CGM,
+                                                     InitListExpr *E) {
+  SmallVector<Constant *, 4> FlatVals;
+  const bool AllConstant = ScanConstInitList(CGM, E, FlatVals);
+  if (!AllConstant)
+    return nullptr;
+
+  unsigned Cursor = 0;
+  QualType InitTy = E->getType();
+  return BuildConstInitializer(InitTy, Cursor, FlatVals, CGM.getTypes());
+}
+
 Value *CGMSHLSLRuntime::EmitHLSLMatrixOperationCall(
     CodeGenFunction &CGF, const clang::Expr *E, llvm::Type *RetType,
     ArrayRef<Value *> paramList) {

+ 2 - 2
tools/clang/lib/CodeGen/CGHLSLRuntime.h

@@ -17,6 +17,7 @@ namespace llvm {
 class Function;
 template <typename T, unsigned N> class SmallVector;
 class Value;
+class Constant;
 class TerminatorInst;
 class Type;
 template <typename T> class ArrayRef;
@@ -61,8 +62,7 @@ public:
   virtual llvm::Value *EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E,
       // The destPtr when emiting aggregate init, for normal case, it will be null.
       llvm::Value *DestPtr) = 0;
-
-  virtual clang::QualType UpdateHLSLIncompleteArrayType(VarDecl &D) = 0;
+  virtual llvm::Constant *EmitHLSLConstInitListExpr(CodeGenModule &CGM, InitListExpr *E) = 0;
 
   virtual void EmitHLSLOutParamConversionInit(
       CodeGenFunction &CGF, const FunctionDecl *FD, const CallExpr *E,

+ 0 - 6
tools/clang/lib/CodeGen/CodeGenModule.cpp

@@ -2037,12 +2037,6 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D) {
   const VarDecl *InitDecl;
   const Expr *InitExpr = D->getAnyInitializer(InitDecl);
 
-  // HLSL Change Begins.
-  if (D->getType()->isIncompleteArrayType() && getLangOpts().HLSL) {
-    getHLSLRuntime().UpdateHLSLIncompleteArrayType(const_cast<VarDecl&>(*D));
-  }
-  // HLSL Change Ends.
-
   if (!InitExpr) {
     // This is a tentative definition; tentative definitions are
     // implicitly initialized with { 0 }.

+ 32 - 8
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -5209,8 +5209,22 @@ unsigned HLSLExternalSource::GetNumBasicElements(QualType anyType) {
     // TODO: consider caching this value for perf
     unsigned total = 0;
     const RecordType *recordType = anyType->getAs<RecordType>();
-    RecordDecl::field_iterator fi = recordType->getDecl()->field_begin();
-    RecordDecl::field_iterator fend = recordType->getDecl()->field_end();
+    RecordDecl * RD = recordType->getDecl();
+    // Take care base.
+    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+      if (CXXRD->getNumBases()) {
+        for (const auto &I : CXXRD->bases()) {
+          const CXXRecordDecl *BaseDecl =
+              cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
+          if (BaseDecl->field_empty())
+            continue;
+          QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
+          total += GetNumBasicElements(parentTy);
+        }
+      }
+    }
+    RecordDecl::field_iterator fi = RD->field_begin();
+    RecordDecl::field_iterator fend = RD->field_end();
     while (fi != fend) {
       total += GetNumBasicElements(fi->getType());
       ++fi;
@@ -8634,19 +8648,29 @@ void hlsl::InitializeInitSequenceForHLSL(Sema *self,
     ->InitializeInitSequenceForHLSL(Entity, Kind, Args, TopLevelOfInitList, initSequence);
 }
 
+// Sums GetNumBasicElements over every initializer in InitList, recursing
+// into nested init lists instead of counting them by their own type.
+static unsigned CaculateInitListSize(HLSLExternalSource *hlslSource,
+                                     const clang::InitListExpr *InitList) {
+  unsigned Count = 0;
+  for (unsigned Idx = 0; Idx < InitList->getNumInits(); ++Idx) {
+    const clang::Expr *Item = InitList->getInit(Idx);
+    if (const InitListExpr *Nested = dyn_cast<InitListExpr>(Item)) {
+      Count += CaculateInitListSize(hlslSource, Nested);
+      continue;
+    }
+    Count += hlslSource->GetNumBasicElements(Item->getType());
+  }
+  return Count;
+}
+
 unsigned hlsl::CaculateInitListArraySizeForHLSL(
   _In_ clang::Sema* sema,
   _In_ const clang::InitListExpr *InitList,
   _In_ const clang::QualType EltTy) {
   HLSLExternalSource *hlslSource = HLSLExternalSource::FromSema(sema);
-  unsigned totalSize = 0;
+  unsigned totalSize = CaculateInitListSize(hlslSource, InitList);
   unsigned eltSize = hlslSource->GetNumBasicElements(EltTy);
 
-  for (unsigned i=0;i<InitList->getNumInits();i++) {
-    const clang::Expr *EltInit = InitList->getInit(i);
-    QualType EltInitTy = EltInit->getType();
-    totalSize += hlslSource->GetNumBasicElements(EltInitTy);
-  }
   if (totalSize > 0 && (totalSize % eltSize)==0) {
     return totalSize / eltSize;
   } else {

+ 17 - 2
tools/clang/test/CodeGenHLSL/staticGlobals.hlsl

@@ -1,11 +1,24 @@
-// RUN: %dxc -E main -T ps_6_0 %s
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
 
 // TODO: create execution test.
 
+// CHECK: [3 x float] [float 5.000000e+00, float 0.000000e+00, float 0.000000e+00]
+// CHECK: [3 x float] [float 6.000000e+00, float 0.000000e+00, float 0.000000e+00]
+// CHECK: [3 x float] [float 7.000000e+00, float 0.000000e+00, float 0.000000e+00]
+// CHECK: [3 x float] [float 8.000000e+00, float 0.000000e+00, float 0.000000e+00]
+// CHECK: [16 x float] [float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01, float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01, float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01, float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01]
+// CHECK: [16 x float] [float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 2.000000e+00, float 2.000000e+00, float 2.000000e+00, float 2.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00]
+// CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
+// CHECK: [16 x float] [float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01]
+
+
 static float4 f0 = {5,6,7,8};
 static float4 f1 = 0;
 static float4 f2 = {0,0,0,0};
+static float4 f3[] = { f0, f1, f2 };
+static float4x4 worldMatrix = { {0,0,0,0}, {1,1,1,1}, {2,2,2,2}, {3,3,3,3} };
 
+static float2x2 m1;
 static float4x4 m0 = { 15,16,17,18,
                        15,16,17,18,
                        15,16,17,18,
@@ -21,5 +34,7 @@ uint i;
 
 float4 main() : SV_TARGET {
   m2[i][1][i] = m0[i][i];
-  return f2 + f1 + f0[i] + m2[1]._m00_m01_m00_m10 + m0[i];
+  m1 = m2[i];
+  m1[0][1] = 2;
+  return f2 + f1 + f0[i] + m2[i]._m00_m01_m00_m10 + m0[i] + m1[i].y + f3[i] + worldMatrix[i];
 }

+ 48 - 0
tools/clang/test/CodeGenHLSL/staticGlobals3.hlsl

@@ -0,0 +1,48 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+
+// t3.b.x
+// CHECK: [3 x float] [float 0.000000e+00, float 2.500000e+01, float 0.000000e+00]
+// t3.b.y
+// CHECK: [3 x float] [float 0.000000e+00, float 2.600000e+01, float 0.000000e+00]
+// t3.c.x
+// CHECK: constant [3 x i32] [i32 0, i32 27, i32 0]
+// t3.c.y
+// CHECK: [3 x i32] [i32 0, i32 28, i32 0]
+// t3.a
+// CHECK: [12 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00]
+// t3.t
+// CHECK: [24 x float] [float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
+
+
+
+static float4 f0 = {5,6,7,8};
+static float4 f1 = 3;
+static float4 f2 = {0,0,0,0};
+static float4 f3[] = { f0, f1, f2 };
+
+static float2x2 m2[4] = { 25,26,27,28,
+                       25,26,27,28,
+                       25,26,27,28,
+                       25,26,27,28 };
+
+struct T {
+   float2x2 a;
+};
+
+struct T2 : T {
+   float2 b;
+   int2   c;
+};
+
+struct T3 : T2 {
+   T t[2];
+};
+
+static T3 t3[] = { { f0, f2, m2 }, { f1, f3, f2, f0} };
+
+uint i;
+
+float4 main() : SV_TARGET {
+  return t3[i].a[i][i] + t3[i].b.xxyy + t3[i].c.xyxy + t3[i].t[i].a[i][i];
+}

+ 11 - 0
tools/clang/test/CodeGenHLSL/staticGlobals4.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -E main -T ps_6_0 %s
+
+static float f1 = { true };
+
+static bool f2 = { 1 };
+
+static int f3 = { false };
+
+float4 main() : SV_TARGET {
+  return f1 + f2 + f3;
+}

+ 11 - 1
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -515,6 +515,8 @@ public:
   TEST_METHOD(CodeGenSrv_Typed_Load2)
   TEST_METHOD(CodeGenStaticGlobals)
   TEST_METHOD(CodeGenStaticGlobals2)
+  TEST_METHOD(CodeGenStaticGlobals3)
+  TEST_METHOD(CodeGenStaticGlobals4)
   TEST_METHOD(CodeGenStaticResource)
   TEST_METHOD(CodeGenStaticResource2)
   TEST_METHOD(CodeGenStruct_Buf1)
@@ -2754,13 +2756,21 @@ TEST_F(CompilerTest, CodeGenSrv_Typed_Load2) {
 }
 
 TEST_F(CompilerTest, CodeGenStaticGlobals) {
-  CodeGenTest(L"..\\CodeGenHLSL\\staticGlobals.hlsl");
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\staticGlobals.hlsl");
 }
 
 TEST_F(CompilerTest, CodeGenStaticGlobals2) {
   CodeGenTest(L"..\\CodeGenHLSL\\staticGlobals2.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenStaticGlobals3) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\staticGlobals3.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenStaticGlobals4) {
+  CodeGenTest(L"..\\CodeGenHLSL\\staticGlobals4.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenStaticResource) {
   CodeGenTest(L"..\\CodeGenHLSL\\static_resource.hlsl");
 }