Quellcode durchsuchen

Add support for generalized templated byte address buffer loads (#2159)

Implements [RW]ByteAddressBuffer.Load<T> for all numerical or aggregate of numerical Ts, leveraging the same code from StructuredBuffer<T>.Load.
Tristan Labelle vor 6 Jahren
Ursprung
Commit
3a7e6d38de

+ 23 - 7
lib/HLSL/HLOperationLower.cpp

@@ -3213,7 +3213,7 @@ ResLoadHelper::ResLoadHelper(CallInst *CI, DxilResource::Kind RK,
 }
 }
 
 
 void TranslateStructBufSubscript(CallInst *CI, Value *handle, Value *status,
 void TranslateStructBufSubscript(CallInst *CI, Value *handle, Value *status,
-                                 hlsl::OP *OP, const DataLayout &DL);
+                                 hlsl::OP *OP, HLResource::Kind RK, const DataLayout &DL);
 
 
 // Create { v0, v1 } from { v0.lo, v0.hi, v1.lo, v1.hi }
 // Create { v0, v1 } from { v0.lo, v0.hi, v1.lo, v1.hi }
 void Make64bitResultForLoad(Type *EltTy, ArrayRef<Value *> resultElts32,
 void Make64bitResultForLoad(Type *EltTy, ArrayRef<Value *> resultElts32,
@@ -3280,7 +3280,7 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
   if (Ty->isPointerTy()) {
   if (Ty->isPointerTy()) {
     DXASSERT(!DxilResource::IsAnyTexture(RK), "Textures should not be treated as structured buffers.");
     DXASSERT(!DxilResource::IsAnyTexture(RK), "Textures should not be treated as structured buffers.");
     TranslateStructBufSubscript(cast<CallInst>(helper.retVal), helper.handle,
     TranslateStructBufSubscript(cast<CallInst>(helper.retVal), helper.handle,
-                                helper.status, OP, DL);
+                                helper.status, OP, RK, DL);
     return;
     return;
   }
   }
 
 
@@ -5925,6 +5925,13 @@ Value *GenerateStructBufLd(Value *handle, Value *bufIdx, Value *offset,
   DXASSERT(resultElts.size() <= 4,
   DXASSERT(resultElts.size() <= 4,
            "buffer load cannot load more than 4 values");
            "buffer load cannot load more than 4 values");
 
 
+  if (bufIdx == nullptr) {
+    // This is actually a byte address buffer load with a struct template type.
+    // The call takes only one coordinates for the offset.
+    bufIdx = offset;
+    offset = UndefValue::get(offset->getType());
+  }
+
   Function *dxilF = OP->GetOpFunc(opcode, EltTy);
   Function *dxilF = OP->GetOpFunc(opcode, EltTy);
   Constant *mask = GetRawBufferMaskForETy(EltTy, NumComponents, OP);
   Constant *mask = GetRawBufferMaskForETy(EltTy, NumComponents, OP);
   Value *Args[] = {OP->GetU32Const((unsigned)opcode),
   Value *Args[] = {OP->GetU32Const((unsigned)opcode),
@@ -6448,14 +6455,23 @@ void TranslateStructBufSubscriptUser(Instruction *user, Value *handle,
 }
 }
 
 
 void TranslateStructBufSubscript(CallInst *CI, Value *handle, Value *status,
 void TranslateStructBufSubscript(CallInst *CI, Value *handle, Value *status,
-                                 hlsl::OP *OP, const DataLayout &DL) {
-  Value *bufIdx = CI->getArgOperand(HLOperandIndex::kSubscriptIndexOpIdx);
+                                 hlsl::OP *OP, HLResource::Kind ResKind, const DataLayout &DL) {
+  Value *subscriptIndex = CI->getArgOperand(HLOperandIndex::kSubscriptIndexOpIdx);
+  Value* bufIdx = nullptr;
+  Value *offset = nullptr;
+  if (ResKind == HLResource::Kind::RawBuffer) {
+    offset = subscriptIndex;
+  }
+  else {
+    // StructuredBuffer, TypedBuffer, etc.
+    bufIdx = subscriptIndex;
+  }
 
 
   for (auto U = CI->user_begin(); U != CI->user_end();) {
   for (auto U = CI->user_begin(); U != CI->user_end();) {
     Value *user = *(U++);
     Value *user = *(U++);
 
 
     TranslateStructBufSubscriptUser(cast<Instruction>(user), handle, bufIdx,
     TranslateStructBufSubscriptUser(cast<Instruction>(user), handle, bufIdx,
-                                    /*baseOffset*/ nullptr, status, OP, DL);
+                                    offset, status, OP, DL);
   }
   }
 }
 }
 }
 }
@@ -6802,11 +6818,11 @@ void TranslateHLSubscript(CallInst *CI, HLSubscriptOpcode opcode,
       Type *ObjTy = pObjHelper->GetResourceType(handle);
       Type *ObjTy = pObjHelper->GetResourceType(handle);
       Type *RetTy = ObjTy->getStructElementType(0);
       Type *RetTy = ObjTy->getStructElementType(0);
       if (RK == DxilResource::Kind::StructuredBuffer) {
       if (RK == DxilResource::Kind::StructuredBuffer) {
-        TranslateStructBufSubscript(CI, handle, /*status*/ nullptr, hlslOP,
+        TranslateStructBufSubscript(CI, handle, /*status*/ nullptr, hlslOP, RK,
                                     helper.dataLayout);
                                     helper.dataLayout);
       } else if (RetTy->isAggregateType() &&
       } else if (RetTy->isAggregateType() &&
                  RK == DxilResource::Kind::TypedBuffer) {
                  RK == DxilResource::Kind::TypedBuffer) {
-        TranslateStructBufSubscript(CI, handle, /*status*/ nullptr, hlslOP,
+        TranslateStructBufSubscript(CI, handle, /*status*/ nullptr, hlslOP, RK,
                                     helper.dataLayout);
                                     helper.dataLayout);
         // Clear offset for typed buf.
         // Clear offset for typed buf.
         for (auto User = handle->user_begin(); User != handle->user_end(); ) {
         for (auto User = handle->user_begin(); User != handle->user_end(); ) {

+ 4 - 4
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -7669,11 +7669,11 @@ def err_hlsl_unsupported_for_version_lower : Error<
 def err_hlsl_unsupported_keyword_for_min_precision : Error<
 def err_hlsl_unsupported_keyword_for_min_precision : Error<
    "%0 is only supported with -enable-16bit-types option">;
    "%0 is only supported with -enable-16bit-types option">;
 def err_hlsl_intrinsic_template_arg_unsupported: Error<
 def err_hlsl_intrinsic_template_arg_unsupported: Error<
-   "Explicit template arguments on intrinsic %0 are not supported.">;
+   "Explicit template arguments on intrinsic %0 are not supported">;
 def err_hlsl_intrinsic_template_arg_requires_2018: Error<
 def err_hlsl_intrinsic_template_arg_requires_2018: Error<
-   "Explicit template arguments on intrinsic %0 requires HLSL version 2018 or above.">;
-def err_hlsl_intrinsic_template_arg_scalar_vector: Error<
-   "Explicit template arguments on intrinsic %0 are limited one to scalar or vector type.">;
+   "Explicit template arguments on intrinsic %0 requires HLSL version 2018 or above">;
+def err_hlsl_intrinsic_template_arg_numeric: Error<
+   "Explicit template arguments on intrinsic %0 must be a single numeric type">;
 }
 }
 def err_hlsl_no_struct_user_defined_type: Error<
 def err_hlsl_no_struct_user_defined_type: Error<
    "User defined type intrinsic arg must be struct">;
    "User defined type intrinsic arg must be struct">;

+ 10 - 7
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -9031,14 +9031,17 @@ Sema::TemplateDeductionResult HLSLExternalSource::DeduceTemplateArgumentsForHLSL
           !IsBABLoad
           !IsBABLoad
               ? diag::err_hlsl_intrinsic_template_arg_unsupported
               ? diag::err_hlsl_intrinsic_template_arg_unsupported
               : !Is2018 ? diag::err_hlsl_intrinsic_template_arg_requires_2018
               : !Is2018 ? diag::err_hlsl_intrinsic_template_arg_requires_2018
-                        : diag::err_hlsl_intrinsic_template_arg_scalar_vector;
+                        : diag::err_hlsl_intrinsic_template_arg_numeric;
       if (IsBABLoad && Is2018 && ExplicitTemplateArgs->size() == 1) {
       if (IsBABLoad && Is2018 && ExplicitTemplateArgs->size() == 1) {
-        Loc = (*ExplicitTemplateArgs)[0].getLocation();
-        QualType explicitType = (*ExplicitTemplateArgs)[0].getArgument().getAsType();
-        ArTypeObjectKind explicitKind = GetTypeObjectKind(explicitType);
-        if (explicitKind == AR_TOBJ_BASIC || explicitKind == AR_TOBJ_VECTOR) {
-          isLegalTemplate = true;
-          argTypes[0] = explicitType;
+        const TemplateArgumentLoc& TemplateArgLoc = (*ExplicitTemplateArgs)[0];
+        Loc = TemplateArgLoc.getLocation();
+        if (TemplateArgLoc.getArgument().getKind() == TemplateArgument::ArgKind::Type) {
+          QualType explicitType = TemplateArgLoc.getArgument().getAsType();
+          ArTypeObjectKind explicitKind = GetTypeObjectKind(explicitType);
+          if (hlsl::IsHLSLNumericOrAggregateOfNumericType(explicitType)) {
+            isLegalTemplate = true;
+            argTypes[0] = explicitType;
+          }
         }
         }
       }
       }
 
 

+ 38 - 0
tools/clang/test/CodeGenHLSL/batch/declarations/resources/byteaddressbuffers/load_type_shapes_sm60.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Tests that ByteAddressBuffer.Load<T> works with all type shapes with SM 6.0
+
+struct S { int i; float f; };
+ByteAddressBuffer buf;
+RWStructuredBuffer<int> out_scalar;
+RWStructuredBuffer<int2> out_vector;
+RWStructuredBuffer<int2x2> out_matrix;
+RWStructuredBuffer<int[2]> out_array;
+RWStructuredBuffer<S> out_struct;
+RWStructuredBuffer<S[2]> out_struct_array;
+
+void main() {
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 100, i32 undef)
+  out_scalar[0] = buf.Load<int>(100);
+  
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 200, i32 undef)
+  out_vector[0] = buf.Load<int2>(200);
+  
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 300, i32 undef)
+  out_matrix[0] = buf.Load<int2x2>(300);
+
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 400, i32 undef)
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 404, i32 undef)
+  out_array[0] = buf.Load<int[2]>(400);
+  
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 500, i32 undef)
+  // CHECK: call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle {{.*}}, i32 504, i32 undef)
+  out_struct[0] = buf.Load<S>(500);
+  
+  // Test loads of arrays of structs because of the SROA behavior that turns them into per-element arrays
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 600, i32 undef)
+  // CHECK: call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle {{.*}}, i32 604, i32 undef)
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle {{.*}}, i32 608, i32 undef)
+  // CHECK: call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle {{.*}}, i32 612, i32 undef)
+  out_struct_array[0] = buf.Load<S[2]>(600);
+}

+ 38 - 0
tools/clang/test/CodeGenHLSL/batch/declarations/resources/byteaddressbuffers/load_type_shapes_sm62.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
+
+// Tests that ByteAddressBuffer.Load<T> works with all type shapes with SM 6.2
+
+struct S { int i; float f; };
+ByteAddressBuffer buf;
+RWStructuredBuffer<int> out_scalar;
+RWStructuredBuffer<int2> out_vector;
+RWStructuredBuffer<int2x2> out_matrix;
+RWStructuredBuffer<int[2]> out_array;
+RWStructuredBuffer<S> out_struct;
+RWStructuredBuffer<S[2]> out_struct_array;
+
+void main() {
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 100, i32 undef, i8 1, i32 4)
+  out_scalar[0] = buf.Load<int>(100);
+  
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 200, i32 undef, i8 3, i32 4)
+  out_vector[0] = buf.Load<int2>(200);
+  
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 300, i32 undef, i8 15, i32 4)
+  out_matrix[0] = buf.Load<int2x2>(300);
+
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 400, i32 undef, i8 1, i32 4)
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 404, i32 undef, i8 1, i32 4)
+  out_array[0] = buf.Load<int[2]>(400);
+  
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 500, i32 undef, i8 1, i32 4)
+  // CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle {{.*}}, i32 504, i32 undef, i8 1, i32 4)
+  out_struct[0] = buf.Load<S>(500);
+  
+  // Test loads of arrays of structs because of the SROA behavior that turns them into per-element arrays
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 600, i32 undef, i8 1, i32 4)
+  // CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle {{.*}}, i32 604, i32 undef, i8 1, i32 4)
+  // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle {{.*}}, i32 608, i32 undef, i8 1, i32 4)
+  // CHECK: call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle {{.*}}, i32 612, i32 undef, i8 1, i32 4)
+  out_struct_array[0] = buf.Load<S[2]>(600);
+}

+ 24 - 15
tools/clang/test/HLSL/intrinsic-examples.hlsl

@@ -8,6 +8,8 @@ float4 FetchFromIndexMap( uniform Texture2D Tex, uniform SamplerState SS, const
     return Sample * 255.0f;
     return Sample * 255.0f;
 }
 }
 
 
+struct S { float f; };
+
 RWByteAddressBuffer uav1 : register(u3);
 RWByteAddressBuffer uav1 : register(u3);
 float4 RWByteAddressBufferMain(uint2 a : A, uint2 b : B) : SV_Target
 float4 RWByteAddressBufferMain(uint2 a : A, uint2 b : B) : SV_Target
 {
 {
@@ -31,6 +33,9 @@ float4 RWByteAddressBufferMain(uint2 a : A, uint2 b : B) : SV_Target
   r += uav1.Load<int32_t3>(20).xyzx;
   r += uav1.Load<int32_t3>(20).xyzx;
   r += uav1.Load<float16_t>(20);
   r += uav1.Load<float16_t>(20);
   r += uav1.Load<float32_t1>(20);
   r += uav1.Load<float32_t1>(20);
+  r += (float4)uav1.Load<float2x2>(20);
+  r += (float4)uav1.Load<float[4]>(20);
+  r += uav1.Load<S>(20).f.xxxx;
 
 
   r += uav1.Load<half4>(4, status);
   r += uav1.Load<half4>(4, status);
   r += uav1.Load<float4>(12, status);
   r += uav1.Load<float4>(12, status);
@@ -38,20 +43,24 @@ float4 RWByteAddressBufferMain(uint2 a : A, uint2 b : B) : SV_Target
   r += uav1.Load<int32_t3>(20, status).xyzx;
   r += uav1.Load<int32_t3>(20, status).xyzx;
   r += uav1.Load<float16_t>(20, status);
   r += uav1.Load<float16_t>(20, status);
   r += uav1.Load<float32_t1>(20, status);
   r += uav1.Load<float32_t1>(20, status);
+  r += (float4)uav1.Load<float2x2>(20, status);
+  r += (float4)uav1.Load<float[4]>(20, status);
+  r += uav1.Load<S>(20, status).f.xxxx;
 
 
   // errors
   // errors
-  r += uav1.Load<float, float3>(16);                        /* expected-error {{Explicit template arguments on intrinsic Load are limited one to scalar or vector type.}} */
+  r += uav1.Load<float, float3>(16);                        /* expected-error {{Explicit template arguments on intrinsic Load must be a single numeric type}} */
   r += uav1.Load<double3>(16);                              /* expected-error {{cannot convert from 'double3' to 'float4'}} */
   r += uav1.Load<double3>(16);                              /* expected-error {{cannot convert from 'double3' to 'float4'}} */
-  r += uav1.Load2<float>(16);                               /* expected-error {{Explicit template arguments on intrinsic Load2 are not supported.}} */
-  r += uav1.Load3<int>(20);                                 /* expected-error {{Explicit template arguments on intrinsic Load3 are not supported.}} */
-  r += uav1.Load4<int16_t>(24);                             /* expected-error {{Explicit template arguments on intrinsic Load4 are not supported.}} */
-  r += uav1.Load<half3x4>(24);                              /* expected-error {{Explicit template arguments on intrinsic Load are limited one to scalar or vector type.}} expected-error {{cannot convert from 'matrix<half, 3, 4>' to 'float4'}} */
-  r += uav1.Load<float, float3>(16, status);                /* expected-error {{Explicit template arguments on intrinsic Load are limited one to scalar or vector type.}} */
+  r += uav1.Load2<float>(16);                               /* expected-error {{Explicit template arguments on intrinsic Load2 are not supported}} */
+  r += uav1.Load3<int>(20);                                 /* expected-error {{Explicit template arguments on intrinsic Load3 are not supported}} */
+  r += uav1.Load4<int16_t>(24);                             /* expected-error {{Explicit template arguments on intrinsic Load4 are not supported}} */
+  r += uav1.Load<half3x4>(24);                              /* expected-error {{cannot convert from 'half3x4' to 'float4'}} */
+  r += uav1.Load<float, float3>(16, status);                /* expected-error {{Explicit template arguments on intrinsic Load must be a single numeric type}} */
   r += uav1.Load<double3>(16, status);                      /* expected-error {{cannot convert from 'double3' to 'float4'}} */
   r += uav1.Load<double3>(16, status);                      /* expected-error {{cannot convert from 'double3' to 'float4'}} */
-  r += uav1.Load2<float>(16, status);                       /* expected-error {{Explicit template arguments on intrinsic Load2 are not supported.}} */
-  r += uav1.Load3<int>(20, status);                         /* expected-error {{Explicit template arguments on intrinsic Load3 are not supported.}} */
-  r += uav1.Load4<int16_t>(24, status);                     /* expected-error {{Explicit template arguments on intrinsic Load4 are not supported.}} */
-  r += uav1.Load<half3x4>(24, status);                      /* expected-error {{Explicit template arguments on intrinsic Load are limited one to scalar or vector type.}} expected-error {{cannot convert from 'matrix<half, 3, 4>' to 'float4'}} */
+  r += uav1.Load2<float>(16, status);                       /* expected-error {{Explicit template arguments on intrinsic Load2 are not supported}} */
+  r += uav1.Load3<int>(20, status);                         /* expected-error {{Explicit template arguments on intrinsic Load3 are not supported}} */
+  r += uav1.Load4<int16_t>(24, status);                     /* expected-error {{Explicit template arguments on intrinsic Load4 are not supported}} */
+  r += uav1.Load<half3x4>(24, status);                      /* expected-error {{cannot convert from 'half3x4' to 'float4'}} */
+
   // valid template argument
   // valid template argument
   uav1.Store(0, r);
   uav1.Store(0, r);
   uav1.Store(0, r.x);
   uav1.Store(0, r.x);
@@ -62,11 +71,11 @@ float4 RWByteAddressBufferMain(uint2 a : A, uint2 b : B) : SV_Target
   struct MyStruct {
   struct MyStruct {
     float4 x;
     float4 x;
   };
   };
-  uav1.Store<float>(0, r);                                  /* expected-error {{Explicit template arguments on intrinsic Store are not supported.}} */
-  uav1.Store<int64_t4>(0, r);                               /* expected-error {{Explicit template arguments on intrinsic Store are not supported.}} */
-  uav1.Store2<float>(0, r.xy);                              /* expected-error {{Explicit template arguments on intrinsic Store2 are not supported.}} */
-  uav1.Store3<float>(0, r.xyz);                             /* expected-error {{Explicit template arguments on intrinsic Store3 are not supported.}} */
-  uav1.Store4<float>(0, r);                                 /* expected-error {{Explicit template arguments on intrinsic Store4 are not supported.}} */
+  uav1.Store<float>(0, r);                                  /* expected-error {{Explicit template arguments on intrinsic Store are not supported}} */
+  uav1.Store<int64_t4>(0, r);                               /* expected-error {{Explicit template arguments on intrinsic Store are not supported}} */
+  uav1.Store2<float>(0, r.xy);                              /* expected-error {{Explicit template arguments on intrinsic Store2 are not supported}} */
+  uav1.Store3<float>(0, r.xyz);                             /* expected-error {{Explicit template arguments on intrinsic Store3 are not supported}} */
+  uav1.Store4<float>(0, r);                                 /* expected-error {{Explicit template arguments on intrinsic Store4 are not supported}} */
   uav1.Store(0, float2x4(1,2,3,4,5,6,7,8));                 /* expected-error {{no matching member function for call to 'Store'}} */
   uav1.Store(0, float2x4(1,2,3,4,5,6,7,8));                 /* expected-error {{no matching member function for call to 'Store'}} */
   uav1.Store<float3x2>(0, float3x2(1,2,3,4,5,6));           /* expected-error {{no matching member function for call to 'Store'}} */
   uav1.Store<float3x2>(0, float3x2(1,2,3,4,5,6));           /* expected-error {{no matching member function for call to 'Store'}} */
   uav1.Store(0, (double3)r.xyz);                            
   uav1.Store(0, (double3)r.xyz);