Răsfoiți Sursa

Merged PR 75: Fix mask when lowering [RW]StructuredBuffer::Load used with built-in type

Fix mask when lowering [RW]StructuredBuffer::Load used with built-in type

- The mask was 0 when calling the Load method on a [RW]StructuredBuffer whose template type is a basic type (such as float2).
- Also, the Load would assert on double3 or double4, since it could not handle splitting the load.
- Changed the lowering to use the GenerateStructBufLd function instead, which is also used in the struct path.
- Added test cases.
Tex Riddell 7 ani în urmă
părinte
comite
6a3898a926
2 a modificat fișierele cu 158 adăugiri și 16 ștergeri
  1. 23 16
      lib/HLSL/HLOperationLower.cpp
  2. 135 0
      tools/clang/test/CodeGenHLSL/struct_buf2.hlsl

+ 23 - 16
lib/HLSL/HLOperationLower.cpp

@@ -3162,6 +3162,11 @@ static Constant *GetRawBufferMaskForETy(Type *Ty, unsigned NumComponents, hlsl::
   return OP->GetI8Const(mask);
   return OP->GetI8Const(mask);
 }
 }
 
 
+void GenerateStructBufLd(Value *handle, Value *bufIdx, Value *offset,
+  Value *status, Type *EltTy,
+  MutableArrayRef<Value *> resultElts, hlsl::OP *OP,
+  IRBuilder<> &Builder, unsigned NumComponents, Constant *alignment);
+
 void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
 void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
                    IRBuilder<> &Builder, hlsl::OP *OP, const DataLayout &DL) {
                    IRBuilder<> &Builder, hlsl::OP *OP, const DataLayout &DL) {
 
 
@@ -3179,6 +3184,22 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
   Type *doubleTy = Builder.getDoubleTy();
   Type *doubleTy = Builder.getDoubleTy();
   Type *EltTy = Ty->getScalarType();
   Type *EltTy = Ty->getScalarType();
   Constant *Alignment = OP->GetI32Const(OP->GetAllocSizeForType(EltTy));
   Constant *Alignment = OP->GetI32Const(OP->GetAllocSizeForType(EltTy));
+  unsigned numComponents = 1;
+  if (Ty->isVectorTy()) {
+    numComponents = Ty->getVectorNumElements();
+  }
+
+  if (RK == HLResource::Kind::StructuredBuffer) {
+    // Basic type case for StructuredBuffer::Load()
+    Value *ResultElts[4];
+    GenerateStructBufLd(helper.handle, helper.addr, OP->GetU32Const(0),
+      helper.status, EltTy, ResultElts, OP, Builder, numComponents, Alignment);
+    Value *retValNew = ScalarizeElements(Ty, ResultElts, Builder);
+    helper.retVal->replaceAllUsesWith(retValNew);
+    helper.retVal = retValNew;
+    return;
+  }
+
   bool is64 = EltTy == i64Ty || EltTy == doubleTy;
   bool is64 = EltTy == i64Ty || EltTy == doubleTy;
   if (is64) {
   if (is64) {
     EltTy = i32Ty;
     EltTy = i32Ty;
@@ -3246,24 +3267,13 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
     // elementOffset, mask, alignment
     // elementOffset, mask, alignment
     loadArgs.emplace_back(undefI);
     loadArgs.emplace_back(undefI);
     Type *rtnTy = helper.retVal->getType();
     Type *rtnTy = helper.retVal->getType();
-    unsigned numComponents = 1;
-    if (VectorType *VTy = dyn_cast<VectorType>(rtnTy)) {
-      rtnTy = VTy->getElementType();
-      numComponents = VTy->getNumElements();
-    }
     loadArgs.emplace_back(GetRawBufferMaskForETy(rtnTy, numComponents, OP));
     loadArgs.emplace_back(GetRawBufferMaskForETy(rtnTy, numComponents, OP));
     loadArgs.emplace_back(Alignment);
     loadArgs.emplace_back(Alignment);
   }
   }
   else if (RK == DxilResource::Kind::TypedBuffer) {
   else if (RK == DxilResource::Kind::TypedBuffer) {
     loadArgs.emplace_back(undefI);
     loadArgs.emplace_back(undefI);
   }
   }
-  else if (RK == DxilResource::Kind::StructuredBuffer) {
-    // elementOffset, mask, alignment
-    loadArgs.emplace_back(
-      OP->GetU32Const(0)); // For case use built-in types in structure buffer.
-    loadArgs.emplace_back(OP->GetU8Const(0)); // When is this case hit?
-    loadArgs.emplace_back(Alignment);
-  }
+
   Value *ResRet =
   Value *ResRet =
       Builder.CreateCall(F, loadArgs, OP->GetOpCodeName(opcode));
       Builder.CreateCall(F, loadArgs, OP->GetOpCodeName(opcode));
 
 
@@ -3271,10 +3281,7 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
   if (!is64) {
   if (!is64) {
     retValNew = ScalarizeResRet(Ty, ResRet, Builder);
     retValNew = ScalarizeResRet(Ty, ResRet, Builder);
   } else {
   } else {
-    unsigned size = 1;
-    if (Ty->isVectorTy()) {
-      size = Ty->getVectorNumElements();
-    }
+    unsigned size = numComponents;
     DXASSERT(size <= 2, "typed buffer only allow 4 dwords");
     DXASSERT(size <= 2, "typed buffer only allow 4 dwords");
     EltTy = Ty->getScalarType();
     EltTy = Ty->getScalarType();
     Value *Elts[2];
     Value *Elts[2];

+ 135 - 0
tools/clang/test/CodeGenHLSL/struct_buf2.hlsl

@@ -71,6 +71,98 @@ struct MyStruct {
 };
 };
 StructuredBuffer<MyStruct> buf1;
 StructuredBuffer<MyStruct> buf1;
 RWStructuredBuffer<MyStruct> buf2;
 RWStructuredBuffer<MyStruct> buf2;
+
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_i1_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_i2_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_i3_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_i4_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_u1_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_u2_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_u3_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_u4_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_h1_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_h2_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_h3_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_h4_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_f1_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_f2_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_f3_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf1_f4_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_d1_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_d2_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_d3_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 8)
+// second half of double3
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_d3_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 16, i8 3, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_d4_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf1_d4_texture_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 16, i8 15, i32 8)
+
+StructuredBuffer<int> buf1_i1;
+StructuredBuffer<int2> buf1_i2;
+StructuredBuffer<int3> buf1_i3;
+StructuredBuffer<int4> buf1_i4;
+StructuredBuffer<uint> buf1_u1;
+StructuredBuffer<uint2> buf1_u2;
+StructuredBuffer<uint3> buf1_u3;
+StructuredBuffer<uint4> buf1_u4;
+StructuredBuffer<half> buf1_h1;
+StructuredBuffer<half2> buf1_h2;
+StructuredBuffer<half3> buf1_h3;
+StructuredBuffer<half4> buf1_h4;
+StructuredBuffer<float> buf1_f1;
+StructuredBuffer<float2> buf1_f2;
+StructuredBuffer<float3> buf1_f3;
+StructuredBuffer<float4> buf1_f4;
+StructuredBuffer<double> buf1_d1;
+StructuredBuffer<double2> buf1_d2;
+StructuredBuffer<double3> buf1_d3;
+StructuredBuffer<double4> buf1_d4;
+
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_i1_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_i2_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_i3_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_i4_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_u1_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_u2_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_u3_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_u4_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_h1_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_h2_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_h3_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_h4_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_f1_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 1, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_f2_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_f3_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 7, i32 4)
+// CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle %buf2_f4_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 4)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_d1_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 3, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_d2_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_d3_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 8)
+// second half of double3
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_d3_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 16, i8 3, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_d4_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 0, i8 15, i32 8)
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle %buf2_d4_UAV_structbuf, i32 %{{[a-zA-Z0-9]+}}, i32 16, i8 15, i32 8)
+
+RWStructuredBuffer<int> buf2_i1;
+RWStructuredBuffer<int2> buf2_i2;
+RWStructuredBuffer<int3> buf2_i3;
+RWStructuredBuffer<int4> buf2_i4;
+RWStructuredBuffer<uint> buf2_u1;
+RWStructuredBuffer<uint2> buf2_u2;
+RWStructuredBuffer<uint3> buf2_u3;
+RWStructuredBuffer<uint4> buf2_u4;
+RWStructuredBuffer<half> buf2_h1;
+RWStructuredBuffer<half2> buf2_h2;
+RWStructuredBuffer<half3> buf2_h3;
+RWStructuredBuffer<half4> buf2_h4;
+RWStructuredBuffer<float> buf2_f1;
+RWStructuredBuffer<float2> buf2_f2;
+RWStructuredBuffer<float3> buf2_f3;
+RWStructuredBuffer<float4> buf2_f4;
+RWStructuredBuffer<double> buf2_d1;
+RWStructuredBuffer<double2> buf2_d2;
+RWStructuredBuffer<double3> buf2_d3;
+RWStructuredBuffer<double4> buf2_d4;
+
+
 int4 main(float idx1 : IDX1, float idx2 : IDX2) : SV_Target {
 int4 main(float idx1 : IDX1, float idx2 : IDX2) : SV_Target {
   uint status;
   uint status;
   float4 r = 0;
   float4 r = 0;
@@ -116,6 +208,49 @@ int4 main(float idx1 : IDX1, float idx2 : IDX2) : SV_Target {
   r.xyz += buf2.Load(idx2, status).d3;
   r.xyz += buf2.Load(idx2, status).d3;
   r.xyzw += buf2.Load(idx2, status).d4;
   r.xyzw += buf2.Load(idx2, status).d4;
 
 
+  // Basic types
+  r.x += buf1_i1.Load(idx1, status);
+  r.xy += buf1_i2.Load(idx1, status);
+  r.xyz += buf1_i3.Load(idx1, status);
+  r.xyzw += buf1_i4.Load(idx1, status);
+  r.x += buf1_u1.Load(idx1, status);
+  r.xy += buf1_u2.Load(idx1, status);
+  r.xyz += buf1_u3.Load(idx1, status);
+  r.xyzw += buf1_u4.Load(idx1, status);
+  r.x += buf1_h1.Load(idx1, status);
+  r.xy += buf1_h2.Load(idx1, status);
+  r.xyz += buf1_h3.Load(idx1, status);
+  r.xyzw += buf1_h4.Load(idx1, status);
+  r.x += buf1_f1.Load(idx1, status);
+  r.xy += buf1_f2.Load(idx1, status);
+  r.xyz += buf1_f3.Load(idx1, status);
+  r.xyzw += buf1_f4.Load(idx1, status);
+  r.x += buf1_d1.Load(idx1, status);
+  r.xy += buf1_d2.Load(idx1, status);
+  r.xyz += buf1_d3.Load(idx1, status);
+  r.xyzw += buf1_d4.Load(idx1, status);
+
+  r.x += buf2_i1.Load(idx2, status);
+  r.xy += buf2_i2.Load(idx2, status);
+  r.xyz += buf2_i3.Load(idx2, status);
+  r.xyzw += buf2_i4.Load(idx2, status);
+  r.x += buf2_u1.Load(idx2, status);
+  r.xy += buf2_u2.Load(idx2, status);
+  r.xyz += buf2_u3.Load(idx2, status);
+  r.xyzw += buf2_u4.Load(idx2, status);
+  r.x += buf2_h1.Load(idx2, status);
+  r.xy += buf2_h2.Load(idx2, status);
+  r.xyz += buf2_h3.Load(idx2, status);
+  r.xyzw += buf2_h4.Load(idx2, status);
+  r.x += buf2_f1.Load(idx2, status);
+  r.xy += buf2_f2.Load(idx2, status);
+  r.xyz += buf2_f3.Load(idx2, status);
+  r.xyzw += buf2_f4.Load(idx2, status);
+  r.x += buf2_d1.Load(idx2, status);
+  r.xy += buf2_d2.Load(idx2, status);
+  r.xyz += buf2_d3.Load(idx2, status);
+  r.xyzw += buf2_d4.Load(idx2, status);
+
   buf2[0].f4 = r;
   buf2[0].f4 = r;
   return r;
   return r;
 }
 }