Browse Source

Struct for normal buf (#43)


* Support struct in non-struct buffer.

* More tests for struct in non-struct buffer.

* Simplify the tests and use them.
Xiang Li 8 years ago
parent
commit
86b7e6ef9d

+ 5 - 0
include/dxc/HLSL/DxilConstants.h

@@ -626,6 +626,11 @@ namespace DXIL {
     // DomainLocation.
     const unsigned kDomainLocationColOpIdx = 1;
 
+    // BufferLoad.
+    const unsigned kBufferLoadHandleOpIdx = 1;
+    const unsigned kBufferLoadCoord0OpIdx = 2;
+    const unsigned kBufferLoadCoord1OpIdx = 3;
+
     // BufferStore.
     const unsigned kBufferStoreHandleOpIdx = 1;
     const unsigned kBufferStoreCoord0OpIdx = 2;

+ 39 - 6
lib/HLSL/HLOperationLower.cpp

@@ -3041,14 +3041,14 @@ void TranslateAtomicCmpXChg(AtomicHelper &helper, IRBuilder<> &Builder,
     _Analysis_assume_(vectorNumElements <= 3);
     for (unsigned i = 0; i < vectorNumElements; i++) {
       Value *Elt = Builder.CreateExtractElement(addr, i);
-      args[DXIL::OperandIndex::kAtomicBinOpCoord0OpIdx + i] = Elt;
+      args[DXIL::OperandIndex::kAtomicCmpExchangeCoord0OpIdx + i] = Elt;
     }
   } else
-    args[DXIL::OperandIndex::kAtomicBinOpCoord0OpIdx] = addr;
+    args[DXIL::OperandIndex::kAtomicCmpExchangeCoord0OpIdx] = addr;
 
   // Set offset for structured buffer.
   if (helper.offset)
-    args[DXIL::OperandIndex::kAtomicBinOpCoord1OpIdx] = helper.offset;
+    args[DXIL::OperandIndex::kAtomicCmpExchangeCoord1OpIdx] = helper.offset;
 
   Value *origVal = Builder.CreateCall(dxilAtomic, args);
   if (helper.originalValue) {
@@ -5629,10 +5629,43 @@ void TranslateHLSubscript(CallInst *CI, HLSubscriptOpcode opcode,
       Value *handle = pObjHelper->handleMap[ptrInst];
       DXIL::ResourceKind RK = pObjHelper->GetRK(ptrInst->getType());
       Translated = true;
-      if (RK == DxilResource::Kind::StructuredBuffer)
-        TranslateStructBufSubscript(CI, handle, /*status*/ nullptr, hlslOP, helper.legacyDataLayout);
-      else
+      Type *ObjTy = ptrInst->getType();
+      Type *RetTy = ObjTy->getStructElementType(0);
+      if (RK == DxilResource::Kind::StructuredBuffer) {
+        TranslateStructBufSubscript(CI, handle, /*status*/ nullptr, hlslOP,
+                                    helper.legacyDataLayout);
+      } else if (RetTy->isAggregateType() &&
+                 RK == DxilResource::Kind::TypedBuffer) {
+        TranslateStructBufSubscript(CI, handle, /*status*/ nullptr, hlslOP,
+                                    helper.legacyDataLayout);
+        // Clear offset for typed buf.
+        for (auto User : handle->users()) {
+          CallInst *CI = cast<CallInst>(User);
+          switch (hlslOP->GetDxilOpFuncCallInst(CI)) {
+          case DXIL::OpCode::BufferLoad: {
+            CI->setArgOperand(DXIL::OperandIndex::kBufferLoadCoord1OpIdx,
+                              UndefValue::get(helper.i32Ty));
+          } break;
+          case DXIL::OpCode::BufferStore: {
+            CI->setArgOperand(DXIL::OperandIndex::kBufferStoreCoord1OpIdx,
+                              UndefValue::get(helper.i32Ty));
+          } break;
+          case DXIL::OpCode::AtomicBinOp: {
+            CI->setArgOperand(DXIL::OperandIndex::kAtomicBinOpCoord1OpIdx,
+                              UndefValue::get(helper.i32Ty));
+          } break;
+          case DXIL::OpCode::AtomicCompareExchange: {
+            CI->setArgOperand(DXIL::OperandIndex::kAtomicCmpExchangeCoord1OpIdx,
+                              UndefValue::get(helper.i32Ty));
+          } break;
+          default:
+            DXASSERT(0, "Invalid operation on resource handle");
+            break;
+          }
+        }
+      } else {
         TranslateDefaultSubscript(CI, helper, pObjHelper, Translated);
+      }
       return;
     }
   }

+ 103 - 4
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -1888,6 +1888,72 @@ uint32_t CGMSHLSLRuntime::AddSampler(VarDecl *samplerDecl) {
   return m_pHLModule->AddSampler(std::move(hlslRes));
 }
 
+// Flattens Ty into its scalar leaf types and appends them to scalarTys.
+// Struct members are visited in declaration order; arrays and vectors
+// contribute their element type once per element. Any type that is not a
+// struct, array, or vector is appended as-is.
+static void CollectScalarTypes(std::vector<llvm::Type *> &scalarTys, llvm::Type *Ty) {
+  if (auto *StructTy = dyn_cast<llvm::StructType>(Ty)) {
+    for (llvm::Type *MemberTy : StructTy->elements())
+      CollectScalarTypes(scalarTys, MemberTy);
+    return;
+  }
+
+  if (auto *ArrayTy = dyn_cast<llvm::ArrayType>(Ty)) {
+    llvm::Type *ElemTy = ArrayTy->getElementType();
+    const auto NumElems = ArrayTy->getNumElements();
+    for (unsigned Idx = 0; Idx < NumElems; ++Idx)
+      CollectScalarTypes(scalarTys, ElemTy);
+    return;
+  }
+
+  if (auto *VecTy = dyn_cast<llvm::VectorType>(Ty)) {
+    llvm::Type *ElemTy = VecTy->getElementType();
+    const auto NumElems = VecTy->getNumElements();
+    for (unsigned Idx = 0; Idx < NumElems; ++Idx)
+      CollectScalarTypes(scalarTys, ElemTy);
+    return;
+  }
+
+  // Leaf: neither an aggregate nor a vector.
+  scalarTys.push_back(Ty);
+}
+
+
+// Flattens an HLSL AST type into its scalar component types, appending one
+// QualType per component to ScalarTys:
+//  - HLSL matrix: the element type, rows * cols times.
+//  - HLSL vector: the element type, once per column reported by
+//    GetRowsAndColsForAny.
+//  - any other record: each field, recursively, in declaration order.
+//  - array: the element type once per element; an array without a constant
+//    size is counted as 5 elements.
+//  - anything else: the type itself.
+static void CollectScalarTypes(std::vector<QualType> &ScalarTys, QualType Ty) {
+  if (Ty->isRecordType()) {
+    const bool IsMatrix = hlsl::IsHLSLMatType(Ty);
+    if (IsMatrix || hlsl::IsHLSLVecType(Ty)) {
+      QualType ComponentTy;
+      unsigned NumRows = 0;
+      unsigned NumCols = 0;
+      unsigned NumComponents = 0;
+      if (IsMatrix) {
+        ComponentTy = hlsl::GetHLSLMatElementType(Ty);
+        hlsl::GetRowsAndCols(Ty, NumRows, NumCols);
+        NumComponents = NumCols * NumRows;
+      } else {
+        ComponentTy = hlsl::GetHLSLVecElementType(Ty);
+        hlsl::GetRowsAndColsForAny(Ty, NumRows, NumCols);
+        NumComponents = NumCols;
+      }
+      for (unsigned Idx = 0; Idx < NumComponents; ++Idx)
+        CollectScalarTypes(ScalarTys, ComponentTy);
+      return;
+    }
+
+    // Plain record. Fall back to getAs<RecordType>() when
+    // getAsStructureType() returns null (the CXXRecord case).
+    const RecordType *RecTy = Ty->getAsStructureType();
+    if (RecTy == nullptr)
+      RecTy = Ty->getAs<RecordType>();
+    for (FieldDecl *Member : RecTy->getDecl()->fields())
+      CollectScalarTypes(ScalarTys, Member->getType());
+    return;
+  }
+
+  if (Ty->isArrayType()) {
+    const clang::ArrayType *ArrTy = Ty->getAsArrayTypeUnsafe();
+    QualType ElemTy = ArrTy->getElementType();
+    // Unsized arrays are counted as 5 elements.
+    // NOTE(review): presumably chosen to exceed the four-component limit
+    // checked by the typed-buffer code in AddUAVSRV -- confirm.
+    unsigned NumElems = 5;
+    if (ArrTy->isConstantArrayType())
+      NumElems = cast<ConstantArrayType>(ArrTy)->getSize().getLimitedValue();
+    for (unsigned Idx = 0; Idx < NumElems; ++Idx)
+      CollectScalarTypes(ScalarTys, ElemTy);
+    return;
+  }
+
+  // Leaf: neither a record nor an array.
+  ScalarTys.push_back(Ty);
+}
+
 uint32_t CGMSHLSLRuntime::AddUAVSRV(VarDecl *decl,
                                     hlsl::DxilResourceBase::Class resClass) {
   llvm::GlobalVariable *val =
@@ -1968,11 +2034,43 @@ uint32_t CGMSHLSLRuntime::AddUAVSRV(VarDecl *decl,
   if (kind != hlsl::DxilResource::Kind::StructuredBuffer) {
     QualType Ty = resultTy;
     QualType EltTy = Ty;
-    if (hlsl::IsHLSLMatType(Ty))
+    if (hlsl::IsHLSLVecType(Ty)) {
+      EltTy = hlsl::GetHLSLVecElementType(Ty);
+    } else if (hlsl::IsHLSLMatType(Ty)) {
       EltTy = hlsl::GetHLSLMatElementType(Ty);
+    } else if (resultTy->isAggregateType()) {
+      // Struct or array in a none-struct resource.
+      std::vector<QualType> ScalarTys;
+      CollectScalarTypes(ScalarTys, resultTy);
+      unsigned size = ScalarTys.size();
+      if (size == 0) {
+        DiagnosticsEngine &Diags = CGM.getDiags();
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error, "object's templated type must have at least one element");
+        Diags.Report(decl->getLocation(), DiagID);
+        return 0;
+      }
+      if (size > 4) {
+        DiagnosticsEngine &Diags = CGM.getDiags();
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error, "elements of typed buffers and textures "
+                                      "must fit in four 32-bit quantities");
+        Diags.Report(decl->getLocation(), DiagID);
+        return 0;
+      }
 
-    if (hlsl::IsHLSLVecType(Ty))
-      EltTy = hlsl::GetHLSLVecElementType(Ty);
+      EltTy = ScalarTys[0];
+      for (QualType ScalarTy : ScalarTys) {
+        if (ScalarTy != EltTy) {
+          DiagnosticsEngine &Diags = CGM.getDiags();
+          unsigned DiagID = Diags.getCustomDiagID(
+              DiagnosticsEngine::Error,
+              "all template type components must have the same type");
+          Diags.Report(decl->getLocation(), DiagID);
+          return 0;
+        }
+      }
+    }
 
     EltTy = EltTy.getCanonicalType();
     bool bSNorm = false;
@@ -1996,8 +2094,9 @@ uint32_t CGMSHLSLRuntime::AddUAVSRV(VarDecl *decl,
       const BuiltinType *BTy = EltTy->getAs<BuiltinType>();
       CompType::Kind kind = BuiltinTyToCompTy(BTy, bSNorm, bUNorm);
       hlslRes->SetCompType(kind);
-    } else
+    } else {
       DXASSERT(!bSNorm && !bUNorm, "snorm/unorm on invalid type");
+    }
   }
   // TODO: set resource
   // hlslRes.SetGloballyCoherent();

+ 20 - 0
tools/clang/test/CodeGenHLSL/BigStructInBuffer.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: elements of typed buffers and textures must fit in four 32-bit quantities
+
+struct Foo { // Flattens to 6 float scalars (2 + 1 + 1 + 2), more than the 4 allowed.
+  float2 a;
+  float b;
+  float c;
+  float d[2];
+};
+
+Buffer<Foo> inputs : register(t1); // Typed buffer of Foo: the declaration the CHECK'd error is expected for.
+
+RWBuffer< int > g_Intensities : register(u1);
+
+[ numthreads( 64, 2, 2 ) ] // NOTE(review): RUN targets ps_6_0 but the entry is compute-style; presumably harmless since only the diagnostic is checked -- confirm.
+void main( uint GI : SV_GroupIndex)
+{
+	g_Intensities = inputs[GI].d[0]; // NOTE(review): assigns to the RWBuffer without an index, unlike the other tests -- confirm this is intended.
+}

+ 13 - 0
tools/clang/test/CodeGenHLSL/EmptyStructInBuffer.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: object's templated type must have at least one element
+
+struct Empty {}; // No fields, so it flattens to zero scalar components.
+Buffer<Empty> eb; // Typed buffer of an empty struct: the declaration the CHECK'd error is expected for.
+
+
+[ numthreads( 64, 2, 2 ) ] // NOTE(review): RUN targets ps_6_0 but the entry is compute-style; presumably harmless since only the diagnostic is checked -- confirm.
+void main( uint GI : SV_GroupIndex)
+{
+        eb[GI]; // Subscript with the result discarded; just references the buffer.
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/structInBuffer.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main  -T cs_6_0 %s
+
+struct Foo { // Four ints: exactly 4 scalar components, all of one type.
+    int a;
+    int b;
+    int c;
+    int d;
+};
+
+Buffer<Foo> inputs : register(t1); // Struct element type in a typed (non-structured) buffer.
+RWBuffer< int > g_Intensities : register(u1);
+
+groupshared Foo sharedData;
+
+[ numthreads( 64, 2, 2 ) ]
+void main( uint GI : SV_GroupIndex)
+{
+	sharedData = inputs[GI]; // Loads a whole struct through the buffer subscript.
+	int rtn;
+	InterlockedAdd(sharedData.d, g_Intensities[GI], rtn); // Atomic add on the groupshared copy; rtn receives the original value.
+	g_Intensities[GI] = rtn + sharedData.d;
+}

+ 20 - 0
tools/clang/test/CodeGenHLSL/structInBuffer2.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -E main  -T cs_6_0 %s
+
+struct Foo { // Two int[2] arrays: 4 scalar components in total, all of one type.
+    int a[2];
+    int d[2];
+};
+
+Buffer<Foo> inputs : register(t1); // Struct-of-arrays element type in a typed (non-structured) buffer.
+RWBuffer< int > g_Intensities : register(u1);
+
+groupshared Foo sharedData;
+
+[ numthreads( 64, 2, 2 ) ]
+void main( uint GI : SV_GroupIndex)
+{
+	sharedData = inputs[GI]; // Loads a whole struct through the buffer subscript.
+	int rtn;
+	InterlockedAdd(sharedData.d[0], g_Intensities[GI], rtn); // Atomic add on an array member of the groupshared copy.
+	g_Intensities[GI] = rtn + sharedData.d[0];
+}

+ 27 - 0
tools/clang/test/CodeGenHLSL/structInBuffer3.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxc -E main  -T cs_6_0 %s | FileCheck %s
+
+// CHECK: all template type components must have the same type
+
+struct Foo { // 4 scalar components, but of mixed types (int and float).
+    int a;
+    int b;
+    float c; // The lone float among ints: the mismatch the CHECK'd error is expected for.
+    int d;
+};
+
+Buffer<Foo> inputs : register(t1); // Typed buffer whose element components are not all the same type.
+RWBuffer< int > g_Intensities : register(u1);
+
+groupshared Foo sharedData;
+
+#ifdef DX12
+[RootSignature("DescriptorTable(UAV(u1, numDescriptors=1), SRV(t1), visibility=SHADER_VISIBILITY_ALL)")]
+#endif
+[ numthreads( 64, 2, 2 ) ]
+void main( uint GI : SV_GroupIndex)
+{
+	sharedData = inputs[GI];
+	int rtn;
+	InterlockedAdd(sharedData.d, g_Intensities[GI], rtn);
+	g_Intensities[GI] = rtn + sharedData.d;
+}

+ 15 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -483,6 +483,9 @@ public:
   TEST_METHOD(CodeGenStruct_BufHasCounter2)
   TEST_METHOD(CodeGenStructCast)
   TEST_METHOD(CodeGenStructCast2)
+  TEST_METHOD(CodeGenStructInBuffer)
+  TEST_METHOD(CodeGenStructInBuffer2)
+  TEST_METHOD(CodeGenStructInBuffer3)
   TEST_METHOD(CodeGenSwitchFloat)
   TEST_METHOD(CodeGenSwitch1)
   TEST_METHOD(CodeGenSwitch2)
@@ -2430,6 +2433,18 @@ TEST_F(CompilerTest, CodeGenStructCast2) {
   CodeGenTest(L"..\\CodeGenHLSL\\StructCast2.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenStructInBuffer) { // Typed Buffer whose element is a struct of four ints.
+  CodeGenTest(L"..\\CodeGenHLSL\\structInBuffer.hlsl"); // NOTE(review): presumably compile-only, since the shader's RUN line has no FileCheck -- confirm.
+}
+
+TEST_F(CompilerTest, CodeGenStructInBuffer2) { // Typed Buffer whose element is a struct of two int[2] arrays.
+  CodeGenTest(L"..\\CodeGenHLSL\\structInBuffer2.hlsl"); // NOTE(review): presumably compile-only, since the shader's RUN line has no FileCheck -- confirm.
+}
+
+TEST_F(CompilerTest, CodeGenStructInBuffer3) { // Struct mixing int and float components in a typed Buffer.
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\structInBuffer3.hlsl"); // The shader's CHECK line expects the "same type" diagnostic.
+}
+
 TEST_F(CompilerTest, CodeGenSwitchFloat) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\switch_float.hlsl");
 }

+ 10 - 0
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -123,6 +123,8 @@ public:
   TEST_METHOD(MultiDimArray)
   TEST_METHOD(NoFunctionParam)
   TEST_METHOD(I8Type)
+  TEST_METHOD(EmptyStructInBuffer)
+  TEST_METHOD(BigStructInBuffer)
 
   TEST_METHOD(ClipCullMaxComponents)
   TEST_METHOD(ClipCullMaxRows)
@@ -1289,6 +1291,14 @@ TEST_F(ValidationTest, I8Type) {
     /*bRegex*/true);
 }
 
+TEST_F(ValidationTest, EmptyStructInBuffer) { // Typed Buffer whose element is an empty struct.
+  TestCheck(L"..\\CodeGenHLSL\\EmptyStructInBuffer.hlsl"); // The shader's CHECK line expects the "at least one element" diagnostic.
+}
+
+TEST_F(ValidationTest, BigStructInBuffer) { // Typed Buffer whose element struct has more than four scalar components.
+  TestCheck(L"..\\CodeGenHLSL\\BigStructInBuffer.hlsl"); // The shader's CHECK line expects the "four 32-bit quantities" diagnostic.
+}
+
 TEST_F(ValidationTest, WhenWaveAffectsGradientThenFail) {
   TestCheck(L"val-wave-failures-ps.hlsl");
 }