ソースを参照

[spirv] Buffer/RWbuffer Load() and operator[] (#589)

* [spirv] Support Load function for Buffer/RWBuffer.

* [spirv] Operator[] for Buffer/RWBuffer.

Also some improvement for handling of vec2 and vec3 cases:
use OpVectorShuffle instead of OpCompositeExtract&OpCompositeConstruct.

* Address code review comments.
Ehsan 8 年 前
コミット
2b5f9ffe1d

+ 92 - 6
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1111,13 +1111,17 @@ uint32_t SPIRVEmitter::doCastExpr(const CastExpr *expr) {
   switch (expr->getCastKind()) {
   case CastKind::CK_LValueToRValue: {
     const uint32_t fromValue = doExpr(subExpr);
-    if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr)) {
-      // By reaching here, it means the vector/matrix element accessing
-      // operation is an lvalue. For vector element accessing, if we generated
-      // a vector shuffle for it and trying to use it as a rvalue, we cannot
-      // do the load here as normal. Need the upper nodes in the AST tree to
-      // handle it properly. For matrix element accessing, load should have
+    if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr) ||
+        isBufferIndexing(dyn_cast<CXXOperatorCallExpr>(subExpr))) {
+      // By reaching here, it means the vector/matrix/Buffer/RWBuffer element
+      // accessing operation is an lvalue. For vector element accessing, if we
+      // generated a vector shuffle for it and trying to use it as a rvalue, we
+      // cannot do the load here as normal. Need the upper nodes in the AST tree
+      // to handle it properly. For matrix element accessing, load should have
       // already happened after creating access chain for each element.
+      // For (RW)Buffer element accessing, load should have already happened
+      // using OpImageFetch.
+
       return fromValue;
     }
 
@@ -1288,6 +1292,54 @@ uint32_t SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   return theBuilder.createSelect(type, condition, trueBranch, falseBranch);
 }
 
+uint32_t SPIRVEmitter::processBufferLoad(const Expr *object,
+                                         const Expr *location) {
+  // Loading for Buffer and RWBuffer translates to an OpImageFetch.
+  // The result type of an OpImageFetch must be a vec4 of float or int.
+  const auto type = object->getType();
+  const uint32_t objectId = doExpr(object);
+  const uint32_t locationId = doExpr(location);
+  const auto sampledType = hlsl::GetHLSLResourceResultType(type);
+  QualType elemType = sampledType;
+  uint32_t elemCount = 1;
+  uint32_t elemTypeId = 0;
+  (void)TypeTranslator::isVectorType(sampledType, &elemType, &elemCount);
+  if (elemType->isFloatingType()) {
+    elemTypeId = theBuilder.getFloat32Type();
+  } else if (elemType->isSignedIntegerType()) {
+    elemTypeId = theBuilder.getInt32Type();
+  } else if (elemType->isUnsignedIntegerType()) {
+    elemTypeId = theBuilder.getUint32Type();
+  } else {
+    emitError("Unimplemented Buffer type");
+    return 0;
+  }
+  const uint32_t resultTypeId =
+      elemCount == 1 ? elemTypeId
+                     : theBuilder.getVecType(elemTypeId, elemCount);
+
+  // Always need to fetch 4 elements.
+  const uint32_t fetchTypeId = theBuilder.getVecType(elemTypeId, 4u);
+  const uint32_t imageFetchResult =
+      theBuilder.createImageFetch(fetchTypeId, objectId, locationId, 0, 0, 0);
+
+  // For the case of buffer elements being vec4, there's no need for extraction
+  // and composition.
+  switch (elemCount) {
+  case 1:
+    return theBuilder.createCompositeExtract(elemTypeId, imageFetchResult, {0});
+  case 2:
+    return theBuilder.createVectorShuffle(resultTypeId, imageFetchResult,
+                                          imageFetchResult, {0, 1});
+  case 3:
+    return theBuilder.createVectorShuffle(resultTypeId, imageFetchResult,
+                                          imageFetchResult, {0, 1, 2});
+  case 4:
+    return imageFetchResult;
+  }
+  llvm_unreachable("Element count of a vector must be 1, 2, 3, or 4.");
+}
+
 uint32_t SPIRVEmitter::processByteAddressBufferLoadStore(
     const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) {
   uint32_t resultId = 0;
@@ -1517,6 +1569,10 @@ uint32_t SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
           typeTranslator.isByteAddressBuffer(objectType)) {
         return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ false);
       }
+      if (TypeTranslator::isBuffer(objectType) ||
+          TypeTranslator::isRWBuffer(objectType))
+        return processBufferLoad(expr->getImplicitObjectArgument(),
+                                 expr->getArg(0));
 
       const uint32_t image = loadIfGLValue(object);
 
@@ -1622,6 +1678,15 @@ uint32_t SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
     }
   }
 
+  { // Handle Buffer/RWBuffer indexing
+    const Expr *baseExpr = nullptr;
+    const Expr *indexExpr = nullptr;
+
+    if (isBufferIndexing(expr, &baseExpr, &indexExpr)) {
+      return processBufferLoad(baseExpr, indexExpr);
+    }
+  }
+
   emitError("unimplemented C++ operator call: %0") << expr->getOperator();
   expr->dump();
   return 0;
@@ -2126,6 +2191,27 @@ bool SPIRVEmitter::isVectorShuffle(const Expr *expr) {
   return false;
 }
 
+bool SPIRVEmitter::isBufferIndexing(const CXXOperatorCallExpr *indexExpr,
+                                    const Expr **base, const Expr **index) {
+  if (!indexExpr)
+    return false;
+
+  // Must be operator[]
+  if (indexExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
+    return false;
+  const Expr *object = indexExpr->getArg(0);
+  const auto objectType = object->getType();
+  if (typeTranslator.isBuffer(objectType) ||
+      typeTranslator.isRWBuffer(objectType)) {
+    if (base)
+      *base = object;
+    if (index)
+      *index = indexExpr->getArg(1);
+    return true;
+  }
+  return false;
+}
+
 bool SPIRVEmitter::isVecMatIndexing(const CXXOperatorCallExpr *vecIndexExpr,
                                     const Expr **base, const Expr **index0,
                                     const Expr **index1) {

+ 13 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -153,6 +153,14 @@ private:
                         const Expr **base, const Expr **index0,
                         const Expr **index1);
 
+  /// \brief Returns true if the given CXXOperatorCallExpr is indexing into a
+  /// Buffer/RWBuffer using operator[].
+  /// On success, writes the base buffer into *base if base is not nullptr, and
+  /// writes the index into *index if index is not nullptr.
+  bool isBufferIndexing(const CXXOperatorCallExpr *,
+                        const Expr **base = nullptr,
+                        const Expr **index = nullptr);
+
   /// Condenses a sequence of HLSLVectorElementExpr starting from the given
   /// expr into one. Writes the original base into *basePtr and the condensed
   /// accessor into *flattenedAccessor.
@@ -424,6 +432,11 @@ private:
   uint32_t processByteAddressBufferLoadStore(const CXXMemberCallExpr *,
                                              uint32_t numWords, bool doStore);
 
+  /// \brief Loads one element from the given Buffer/RWBuffer object at the
+  /// given location. The type of the loaded element matches the type in the
+  /// declaration for the (RW)Buffer object.
+  uint32_t processBufferLoad(const Expr *object, const Expr *address);
+
 private:
   /// \brief Wrapper method to create an error message and report it
   /// in the diagnostic engine associated with this consumer.

+ 14 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -211,6 +211,20 @@ bool TypeTranslator::isByteAddressBuffer(QualType type) {
   return false;
 }
 
+bool TypeTranslator::isRWBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "RWBuffer";
+  }
+  return false;
+}
+
+bool TypeTranslator::isBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "Buffer";
+  }
+  return false;
+}
+
 bool TypeTranslator::isVectorType(QualType type, QualType *elemType,
                                   uint32_t *elemCount) {
   bool isVec = false;

+ 6 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -52,6 +52,12 @@ public:
   /// \brief Returns true if the given type is the HLSL RWByteAddressBufferType.
   bool isRWByteAddressBuffer(QualType type);
 
+  /// \brief Returns true if the given type is the HLSL Buffer type.
+  static bool isBuffer(QualType type);
+
+  /// \brief Returns true if the given type is the HLSL RWBuffer type.
+  static bool isRWBuffer(QualType type);
+
   /// \brief Returns true if the given type will be translated into a SPIR-V
   /// scalar type. This includes normal scalar types, vectors of size 1, and
   /// 1x1 matrices. If scalarType is not nullptr, writes the scalar type to

+ 66 - 0
tools/clang/test/CodeGenSPIRV/buffer.load.hlsl

@@ -0,0 +1,66 @@
+// Run: %dxc -T ps_6_0 -E main
+
+Buffer<int> intbuf;
+Buffer<uint> uintbuf;
+Buffer<float> floatbuf;
+RWBuffer<int2> int2buf;
+RWBuffer<uint2> uint2buf;
+RWBuffer<float2> float2buf;
+Buffer<int3> int3buf;
+Buffer<uint3> uint3buf;
+Buffer<float3> float3buf;
+RWBuffer<int4> int4buf;
+RWBuffer<uint4> uint4buf;
+RWBuffer<float4> float4buf;
+
+void main() {
+  int address;
+
+// CHECK:      [[f1:%\d+]] = OpImageFetch %v4int %intbuf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpCompositeExtract %int [[f1]] 0
+  int int1 = intbuf.Load(address);
+
+// CHECK:      [[f2:%\d+]] = OpImageFetch %v4uint %uintbuf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpCompositeExtract %uint [[f2]] 0
+  uint uint1 = uintbuf.Load(address);
+
+// CHECK:      [[f3:%\d+]] = OpImageFetch %v4float %floatbuf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpCompositeExtract %float [[f3]] 0
+  float float1 = floatbuf.Load(address);
+
+// CHECK:      [[f4:%\d+]] = OpImageFetch %v4int %int2buf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpVectorShuffle %v2int [[f4]] [[f4]] 0 1
+  int2 int2 = int2buf.Load(address);
+
+// CHECK:      [[f5:%\d+]] = OpImageFetch %v4uint %uint2buf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpVectorShuffle %v2uint [[f5]] [[f5]] 0 1
+  uint2 uint2 = uint2buf.Load(address);
+
+// CHECK:      [[f6:%\d+]] = OpImageFetch %v4float %float2buf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpVectorShuffle %v2float [[f6]] [[f6]] 0 1
+  float2 float2 = float2buf.Load(address);
+
+// CHECK:      [[f7:%\d+]] = OpImageFetch %v4int %int3buf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpVectorShuffle %v3int [[f7]] [[f7]] 0 1 2
+  int3 int3 = int3buf.Load(address);
+
+// CHECK:      [[f8:%\d+]] = OpImageFetch %v4uint %uint3buf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpVectorShuffle %v3uint [[f8]] [[f8]] 0 1 2
+  uint3 uint3 = uint3buf.Load(address);
+
+// CHECK:      [[f9:%\d+]] = OpImageFetch %v4float %float3buf {{%\d+}} None
+// CHECK-NEXT: {{%\d+}} = OpVectorShuffle %v3float [[f9]] [[f9]] 0 1 2
+  float3 float3 = float3buf.Load(address);
+
+// CHECK:      {{%\d+}} = OpImageFetch %v4int %int4buf {{%\d+}} None
+// CHECK-NEXT: OpStore %int4 {{%\d+}}
+  int4 int4 = int4buf.Load(address);
+
+// CHECK:      {{%\d+}} = OpImageFetch %v4uint %uint4buf {{%\d+}} None
+// CHECK-NEXT: OpStore %uint4 {{%\d+}}
+  uint4 uint4 = uint4buf.Load(address);
+
+// CHECK:      {{%\d+}} = OpImageFetch %v4float %float4buf {{%\d+}} None
+// CHECK-NEXT: OpStore %float4 {{%\d+}}
+  float4 float4 = float4buf.Load(address);
+}

+ 75 - 0
tools/clang/test/CodeGenSPIRV/op.buffer.access.hlsl

@@ -0,0 +1,75 @@
+// Run: %dxc -T ps_6_0 -E main
+
+Buffer<int> intbuf;
+Buffer<uint> uintbuf;
+Buffer<float> floatbuf;
+RWBuffer<int2> int2buf;
+RWBuffer<uint2> uint2buf;
+RWBuffer<float2> float2buf;
+Buffer<int3> int3buf;
+Buffer<uint3> uint3buf;
+Buffer<float3> float3buf;
+RWBuffer<int4> int4buf;
+RWBuffer<uint4> uint4buf;
+RWBuffer<float4> float4buf;
+
+void main() {
+  int address;
+
+// CHECK:      [[f1:%\d+]] = OpImageFetch %v4int %intbuf {{%\d+}} None
+// CHECK-NEXT: [[r1:%\d+]] = OpCompositeExtract %int [[f1]] 0
+// CHECK-NEXT: OpStore %int1 [[r1]]
+  int int1 = intbuf[address];
+
+// CHECK:      [[f2:%\d+]] = OpImageFetch %v4uint %uintbuf {{%\d+}} None
+// CHECK-NEXT: [[r2:%\d+]] = OpCompositeExtract %uint [[f2]] 0
+// CHECK-NEXT: OpStore %uint1 [[r2]]
+  uint uint1 = uintbuf[address];
+
+// CHECK:      [[f3:%\d+]] = OpImageFetch %v4float %floatbuf {{%\d+}} None
+// CHECK-NEXT: [[r3:%\d+]] = OpCompositeExtract %float [[f3]] 0
+// CHECK-NEXT: OpStore %float1 [[r3]]
+  float float1 = floatbuf[address];
+
+// CHECK:      [[f4:%\d+]] = OpImageFetch %v4int %int2buf {{%\d+}} None
+// CHECK-NEXT: [[r4:%\d+]] = OpVectorShuffle %v2int [[f4]] [[f4]] 0 1
+// CHECK-NEXT: OpStore %int2 [[r4]]
+  int2 int2 = int2buf[address];
+
+// CHECK:      [[f5:%\d+]] = OpImageFetch %v4uint %uint2buf {{%\d+}} None
+// CHECK-NEXT: [[r5:%\d+]] = OpVectorShuffle %v2uint [[f5]] [[f5]] 0 1
+// CHECK-NEXT: OpStore %uint2 [[r5]]
+  uint2 uint2 = uint2buf[address];
+
+// CHECK:      [[f6:%\d+]] = OpImageFetch %v4float %float2buf {{%\d+}} None
+// CHECK-NEXT: [[r6:%\d+]] = OpVectorShuffle %v2float [[f6]] [[f6]] 0 1
+// CHECK-NEXT: OpStore %float2 [[r6]]
+  float2 float2 = float2buf[address];
+
+// CHECK:      [[f7:%\d+]] = OpImageFetch %v4int %int3buf {{%\d+}} None
+// CHECK-NEXT: [[r7:%\d+]] = OpVectorShuffle %v3int [[f7]] [[f7]] 0 1 2
+// CHECK-NEXT: OpStore %int3 [[r7]]
+  int3 int3 = int3buf[address];
+
+// CHECK:      [[f8:%\d+]] = OpImageFetch %v4uint %uint3buf {{%\d+}} None
+// CHECK-NEXT: [[r8:%\d+]] = OpVectorShuffle %v3uint [[f8]] [[f8]] 0 1 2
+// CHECK-NEXT: OpStore %uint3 [[r8]]
+  uint3 uint3 = uint3buf[address];
+
+// CHECK:      [[f9:%\d+]] = OpImageFetch %v4float %float3buf {{%\d+}} None
+// CHECK-NEXT: [[r9:%\d+]] = OpVectorShuffle %v3float [[f9]] [[f9]] 0 1 2
+// CHECK-NEXT: OpStore %float3 [[r9]]
+  float3 float3 = float3buf[address];
+
+// CHECK:      [[r10:%\d+]] = OpImageFetch %v4int %int4buf {{%\d+}} None
+// CHECK-NEXT: OpStore %int4 [[r10]]
+  int4 int4 = int4buf[address];
+
+// CHECK:      [[r11:%\d+]] = OpImageFetch %v4uint %uint4buf {{%\d+}} None
+// CHECK-NEXT: OpStore %uint4 [[r11]]
+  uint4 uint4 = uint4buf[address];
+
+// CHECK:      [[r12:%\d+]] = OpImageFetch %v4float %float4buf {{%\d+}} None
+// CHECK-NEXT: OpStore %float4 [[r12]]
+  float4 float4 = float4buf[address];
+}

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -197,6 +197,9 @@ TEST_F(FileTest, OpStructAccess) { runFileTest("op.struct.access.hlsl"); }
 TEST_F(FileTest, OpCBufferAccess) { runFileTest("op.cbuffer.access.hlsl"); }
 TEST_F(FileTest, OpStructArray) { runFileTest("op.array.access.hlsl"); }
 
+// For Buffer/RWBuffer accessing operator
+TEST_F(FileTest, OpBufferAccess) { runFileTest("op.buffer.access.hlsl"); }
+
 // For casting
 TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); }
 TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }
@@ -343,6 +346,9 @@ TEST_F(FileTest, ByteAddressBufferStore) {
   runFileTest("method.byte-address-buffer.store.hlsl");
 }
 
+// For Buffer/RWBuffer methods
+TEST_F(FileTest, BufferLoad) { runFileTest("buffer.load.hlsl"); }
+
 // For intrinsic functions
 TEST_F(FileTest, IntrinsicsDot) { runFileTest("intrinsics.dot.hlsl"); }
 TEST_F(FileTest, IntrinsicsMul) { runFileTest("intrinsics.mul.hlsl"); }