Parcourir la source

Set alignment for element global variable. (#2615)

Xiang Li il y a 5 ans
Parent
commit
3defb835f1

+ 29 - 4
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -90,6 +90,8 @@ public:
                                   bool hasPrecise, DxilTypeSystem &typeSys,
                                   bool hasPrecise, DxilTypeSystem &typeSys,
                                   const DataLayout &DL,
                                   const DataLayout &DL,
                                   SmallVector<Value *, 32> &DeadInsts);
                                   SmallVector<Value *, 32> &DeadInsts);
+  static unsigned GetEltAlign(unsigned ValueAlign, const DataLayout &DL,
+                              Type *EltTy, unsigned Offset);
   // Lower memcpy related to V.
   // Lower memcpy related to V.
   static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
   static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                           DxilTypeSystem &typeSys, const DataLayout &DL,
                           DxilTypeSystem &typeSys, const DataLayout &DL,
@@ -3190,6 +3192,18 @@ static Constant *GetEltInit(Type *Ty, Constant *Init, unsigned idx,
   }
   }
 }
 }
 
 
+unsigned SROA_Helper::GetEltAlign(unsigned ValueAlign, const DataLayout &DL,
+                                  Type *EltTy, unsigned Offset) {
+  unsigned Alignment = ValueAlign;
+  if (ValueAlign == 0) {
+    // The minimum alignment which users can rely on when the explicit
+    // alignment is omitted or zero is that required by the ABI for this
+    // type.
+    Alignment = DL.getABITypeAlignment(EltTy);
+  }
+  return MinAlign(Alignment, Offset);
+}
+
 /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
 /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
 /// Then do SROA on V.
 /// Then do SROA on V.
 bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
 bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
@@ -3222,21 +3236,24 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
   GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
   GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
   unsigned AddressSpace = GV->getType()->getAddressSpace();
   unsigned AddressSpace = GV->getType()->getAddressSpace();
   GlobalValue::LinkageTypes linkage = GV->getLinkage();
   GlobalValue::LinkageTypes linkage = GV->getLinkage();
-
+  const unsigned Alignment = GV->getAlignment();
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
   if (StructType *ST = dyn_cast<StructType>(Ty)) {
     // Skip HLSL object types.
     // Skip HLSL object types.
     if (dxilutil::IsHLSLObjectType(ST))
     if (dxilutil::IsHLSLObjectType(ST))
       return false;
       return false;
     unsigned numTypes = ST->getNumContainedTypes();
     unsigned numTypes = ST->getNumContainedTypes();
     Elts.reserve(numTypes);
     Elts.reserve(numTypes);
+    unsigned Offset = 0;
     //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
     //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
     for (int i = 0, e = numTypes; i != e; ++i) {
     for (int i = 0, e = numTypes; i != e; ++i) {
-      Constant *EltInit = GetEltInit(Ty, Init, i, ST->getElementType(i));
+      Type *EltTy = ST->getElementType(i);
+      Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
       GlobalVariable *EltGV = new llvm::GlobalVariable(
       GlobalVariable *EltGV = new llvm::GlobalVariable(
           *M, ST->getContainedType(i), /*IsConstant*/ isConst, linkage,
           *M, ST->getContainedType(i), /*IsConstant*/ isConst, linkage,
           /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
           /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
           /*InsertBefore*/ nullptr, TLMode, AddressSpace);
           /*InsertBefore*/ nullptr, TLMode, AddressSpace);
-
+      EltGV->setAlignment(GetEltAlign(Alignment, DL, EltTy, Offset));
+      Offset += DL.getTypeAllocSize(EltTy);
       //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
       //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
       // TODO: set precise.
       // TODO: set precise.
       // if (hasPrecise || FA.IsPrecise())
       // if (hasPrecise || FA.IsPrecise())
@@ -3248,6 +3265,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
     unsigned numElts = VT->getNumElements();
     unsigned numElts = VT->getNumElements();
     Elts.reserve(numElts);
     Elts.reserve(numElts);
     Type *EltTy = VT->getElementType();
     Type *EltTy = VT->getElementType();
+    unsigned Offset = 0;
     //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
     //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
     for (int i = 0, e = numElts; i != e; ++i) {
     for (int i = 0, e = numElts; i != e; ++i) {
       Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
       Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
@@ -3255,7 +3273,8 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
           *M, EltTy, /*IsConstant*/ isConst, linkage,
           *M, EltTy, /*IsConstant*/ isConst, linkage,
           /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
           /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
           /*InsertBefore*/ nullptr, TLMode, AddressSpace);
           /*InsertBefore*/ nullptr, TLMode, AddressSpace);
-
+      EltGV->setAlignment(GetEltAlign(Alignment, DL, EltTy, Offset));
+      Offset += DL.getTypeAllocSize(EltTy);
       //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
       //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
       // TODO: set precise.
       // TODO: set precise.
       // if (hasPrecise || FA.IsPrecise())
       // if (hasPrecise || FA.IsPrecise())
@@ -3287,6 +3306,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
       StructType *ElST = cast<StructType>(ElTy);
       StructType *ElST = cast<StructType>(ElTy);
       unsigned numTypes = ElST->getNumContainedTypes();
       unsigned numTypes = ElST->getNumContainedTypes();
       Elts.reserve(numTypes);
       Elts.reserve(numTypes);
+      unsigned Offset = 0;
       //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
       //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
       for (int i = 0, e = numTypes; i != e; ++i) {
       for (int i = 0, e = numTypes; i != e; ++i) {
         Type *EltTy =
         Type *EltTy =
@@ -3297,6 +3317,8 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
             /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
             /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
             /*InsertBefore*/ nullptr, TLMode, AddressSpace);
             /*InsertBefore*/ nullptr, TLMode, AddressSpace);
 
 
+        EltGV->setAlignment(GetEltAlign(Alignment, DL, EltTy, Offset));
+        Offset += DL.getTypeAllocSize(EltTy);
         //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
         //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
         // TODO: set precise.
         // TODO: set precise.
         // if (hasPrecise || FA.IsPrecise())
         // if (hasPrecise || FA.IsPrecise())
@@ -3315,6 +3337,7 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
 
 
       ArrayType *scalarArrayTy =
       ArrayType *scalarArrayTy =
           CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
           CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
+      unsigned Offset = 0;
 
 
       for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
       for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
         Constant *EltInit = GetEltInit(Ty, Init, i, scalarArrayTy);
         Constant *EltInit = GetEltInit(Ty, Init, i, scalarArrayTy);
@@ -3325,6 +3348,8 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
         // TODO: set precise.
         // TODO: set precise.
         // if (hasPrecise)
         // if (hasPrecise)
         //  HLModule::MarkPreciseAttributeWithMetadata(NA);
         //  HLModule::MarkPreciseAttributeWithMetadata(NA);
+        EltGV->setAlignment(GetEltAlign(Alignment, DL, scalarArrayTy, Offset));
+        Offset += DL.getTypeAllocSize(scalarArrayTy);
         Elts.push_back(EltGV);
         Elts.push_back(EltGV);
       }
       }
     } else
     } else

+ 21 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/modifiers/groupshared/alignment/group_share_align2.hlsl

@@ -0,0 +1,21 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// Make sure alignment is 4.
+// CHECK:@{{.*}} = addrspace(3) global [8 x float] undef, align 4
+// CHECK:store float {{.*}}, float addrspace(3)* {{.*}}, align 4
+// CHECK:store float {{.*}}, float addrspace(3)* {{.*}}, align 4
+// CHECK:load float, float addrspace(3)* {{.*}}, align 4
+// CHECK:load float, float addrspace(3)* {{.*}}, align 4
+
+struct S {
+   float2 a;
+};
+
+groupshared S a[4];
+RWBuffer<float2> u;
+[numthreads(8,8,1)]
+void main(uint3 tid : SV_DispatchThreadID) {
+  a[tid.x].a = tid.y;
+  GroupMemoryBarrier();
+  u[tid.y] = a[tid.y].a;
+}