Browse Source

[spirv] Add test for struct accessing and assignment (#551)

Lei Zhang 8 years ago
parent
commit
1e6d05ac59

+ 29 - 13
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1588,20 +1588,15 @@ uint32_t SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
 }
 
 uint32_t SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
-  const uint32_t base = doExpr(expr->getBase());
-  const auto *memberDecl = expr->getMemberDecl();
-  if (const auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
-    const auto index = theBuilder.getConstantInt32(fieldDecl->getFieldIndex());
-    const uint32_t fieldType =
-        typeTranslator.translateType(fieldDecl->getType());
-    const uint32_t ptrType = theBuilder.getPointerType(
-        fieldType, declIdMapper.resolveStorageClass(expr->getBase()));
-    return theBuilder.createAccessChain(ptrType, base, {index});
-  }
+  llvm::SmallVector<uint32_t, 4> indices;
 
-  emitError("Decl '%0' in MemberExpr is not supported yet.")
-      << memberDecl->getDeclKindName();
-  return 0;
+  const Expr *baseExpr = collectStructIndices(expr, &indices);
+  const uint32_t base = doExpr(baseExpr);
+
+  const uint32_t fieldType = typeTranslator.translateType(expr->getType());
+  const uint32_t ptrType = theBuilder.getPointerType(
+      fieldType, declIdMapper.resolveStorageClass(baseExpr));
+  return theBuilder.createAccessChain(ptrType, base, indices);
 }
 
 uint32_t SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
@@ -2391,6 +2386,27 @@ uint32_t SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
   return 0;
 }
 
+const Expr *
+SPIRVEmitter::collectStructIndices(const MemberExpr *expr,
+                                   llvm::SmallVectorImpl<uint32_t> *indices) {
+  const Expr *base = expr->getBase();
+  if (const auto *memExpr = dyn_cast<MemberExpr>(base)) {
+    base = collectStructIndices(memExpr, indices);
+  } else {
+    indices->clear();
+  }
+
+  const auto *memberDecl = expr->getMemberDecl();
+  if (const auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
+    indices->push_back(theBuilder.getConstantInt32(fieldDecl->getFieldIndex()));
+  } else {
+    emitError("Decl '%0' in MemberExpr is not supported yet.")
+        << memberDecl->getDeclKindName();
+  }
+
+  return base;
+}
+
 uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
                                   QualType toBoolType) {
   if (isSameScalarOrVecType(fromType, toBoolType))

+ 6 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -210,6 +210,12 @@ private:
   uint32_t processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                  const BinaryOperatorKind opcode);
 
+  /// Collects all indices (SPIR-V constant values) from consecutive MemberExprs
+  /// and writes into indices. Returns the real base (the first Expr that is not
+  /// a MemberExpr).
+  const Expr *collectStructIndices(const MemberExpr *expr,
+                                   llvm::SmallVectorImpl<uint32_t> *indices);
+
 private:
   /// Processes the given expr, casts the result into the given bool (vector)
   /// type and returns the <result-id> of the casted value.

+ 19 - 1
tools/clang/test/CodeGenSPIRV/binary-op.assign.hlsl

@@ -1,6 +1,13 @@
 // Run: %dxc -T ps_6_0 -E main
 
-// TODO: assignment for composite types
+struct S {
+    float x;
+};
+
+struct T {
+    float y;
+    S z;
+};
 
 void main() {
     int a, b, c;
@@ -20,4 +27,15 @@ void main() {
 // CHECK-NEXT: OpStore %a [[a1]]
 // CHECK-NEXT: OpStore %a [[a1]]
     a = a = a;
+
+    T p, q;
+
+// CHECK-NEXT: [[q:%\d+]] = OpLoad %T %q
+// CHECK-NEXT: OpStore %p [[q]]
+    p = q;     // assign as a whole
+// CHECK-NEXT: [[q1ptr:%\d+]] = OpAccessChain %_ptr_Function_S %q %int_1
+// CHECK-NEXT: [[q1val:%\d+]] = OpLoad %S [[q1ptr]]
+// CHECK-NEXT: [[p1ptr:%\d+]] = OpAccessChain %_ptr_Function_S %p %int_1
+// CHECK-NEXT: OpStore [[p1ptr]] [[q1val]]
+    p.z = q.z; // assign nested struct
 }

+ 75 - 0
tools/clang/test/CodeGenSPIRV/op.struct.access.hlsl

@@ -0,0 +1,75 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    bool a;
+    uint2 b;
+    float2x3 c;
+};
+
+struct T {
+    int h; // Nested struct
+    S i;
+};
+
+void main() {
+    T t;
+
+// CHECK:      [[h:%\d+]] = OpAccessChain %_ptr_Function_int %t %int_0
+// CHECK-NEXT: {{%\d+}} = OpLoad %int [[h]]
+    int v1 = t.h;
+// CHECK:      [[a:%\d+]] = OpAccessChain %_ptr_Function_bool %t %int_1 %int_0
+// CHECK-NEXT: {{%\d+}} = OpLoad %bool [[a]]
+    bool v2 = t.i.a;
+
+// CHECK:      [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1
+// CHECK-NEXT: [[b0:%\d+]] = OpAccessChain %_ptr_Function_uint [[b]] %uint_0
+// CHECK-NEXT: {{%\d+}} = OpLoad %uint [[b0]]
+    uint v3 = t.i.b[0];
+// CHECK:      [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1
+// CHECK-NEXT: {{%\d+}} = OpLoad %v2uint [[b]]
+    uint2 v4 = t.i.b.rg;
+
+// CHECK:      [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2
+// CHECK-NEXT: [[c00p:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_0 %int_0
+// CHECK-NEXT: [[c00v:%\d+]] = OpLoad %float [[c00p]]
+// CHECK-NEXT: [[c11p:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_1 %int_1
+// CHECK-NEXT: [[c11v:%\d+]] = OpLoad %float [[c11p]]
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %v2float [[c00v]] [[c11v]]
+    float2 v5 = t.i.c._11_22;
+// CHECK:      [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2
+// CHECK-NEXT: [[c1:%\d+]] = OpAccessChain %_ptr_Function_v3float [[c]] %uint_1
+// CHECK-NEXT: {{%\d+}} = OpLoad %v3float [[c1]]
+    float3 v6 = t.i.c[1];
+
+// CHECK:      [[h:%\d+]] = OpAccessChain %_ptr_Function_int %t %int_0
+// CHECK-NEXT: OpStore [[h]] {{%\d+}}
+    t.h = v1;
+// CHECK:      [[a:%\d+]] = OpAccessChain %_ptr_Function_bool %t %int_1 %int_0
+// CHECK-NEXT: OpStore [[a]] {{%\d+}}
+    t.i.a = v2;
+
+// CHECK:      [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1
+// CHECK-NEXT: [[b1:%\d+]] = OpAccessChain %_ptr_Function_uint [[b]] %uint_1
+// CHECK-NEXT: OpStore [[b1]] {{%\d+}}
+    t.i.b[1] = v3;
+// CHECK:      [[v4:%\d+]] = OpLoad %v2uint %v4
+// CHECK-NEXT: [[b:%\d+]] = OpAccessChain %_ptr_Function_v2uint %t %int_1 %int_1
+// CHECK-NEXT: [[bv:%\d+]] = OpLoad %v2uint [[b]]
+// CHECK-NEXT: [[gr:%\d+]] = OpVectorShuffle %v2uint [[bv]] [[v4]] 3 2
+// CHECK-NEXT: OpStore [[b]] [[gr]]
+    t.i.b.gr = v4;
+
+// CHECK:      [[v5:%\d+]] = OpLoad %v2float %v5
+// CHECK-NEXT: [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2
+// CHECK-NEXT: [[v50:%\d+]] = OpCompositeExtract %float [[v5]] 0
+// CHECK-NEXT: [[c11:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_1 %int_1
+// CHECK-NEXT: OpStore [[c11]] [[v50]]
+// CHECK-NEXT: [[v51:%\d+]] = OpCompositeExtract %float [[v5]] 1
+// CHECK-NEXT: [[c00:%\d+]] = OpAccessChain %_ptr_Function_float [[c]] %int_0 %int_0
+// CHECK-NEXT: OpStore [[c00]] [[v51]]
+    t.i.c._22_11 = v5;
+// CHECK:      [[c:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %t %int_1 %int_2
+// CHECK-NEXT: [[c0:%\d+]] = OpAccessChain %_ptr_Function_v3float [[c]] %uint_0
+// CHECK-NEXT: OpStore [[c0]] {{%\d+}}
+    t.i.c[0] = v6;
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -178,6 +178,9 @@ TEST_F(FileTest, OpMatrixAccess1x1) {
   runFileTest("op.matrix.access.1x1.hlsl");
 }
 
+// For struct accessing operator
+TEST_F(FileTest, OpStructAccess) { runFileTest("op.struct.access.hlsl"); }
+
 // For casting
 TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); }
 TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }