Przeglądaj źródła

[spirv] Add tests for struct types and emit struct debug names (#542)

Lei Zhang 8 lat temu
rodzic
commit
b529d22725

+ 3 - 1
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -241,7 +241,9 @@ public:
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
   uint32_t getMatType(uint32_t colType, uint32_t colCount);
   uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);
-  uint32_t getStructType(llvm::ArrayRef<uint32_t> fieldTypes);
+  uint32_t getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
+                         llvm::StringRef structName = "",
+                         llvm::ArrayRef<llvm::StringRef> fieldNames = {});
   uint32_t getFunctionType(uint32_t returnType,
                            llvm::ArrayRef<uint32_t> paramTypes);
 

+ 3 - 1
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -59,7 +59,9 @@ public:
 
   /// \brief Returns the <result-id> that defines the given Type. If the type
   /// has not been defined, it will define and store its instruction.
-  uint32_t getResultIdForType(const Type *);
+  /// If isRegistered is not nullptr, *isRegistered will contain whether the
+  /// type was previously seen.
+  uint32_t getResultIdForType(const Type *type, bool *isRegistered = nullptr);
 
   /// \brief Returns the <result-id> that defines the given Constant. If the
   /// constant has not been defined, it will define and return its result-id.

+ 17 - 2
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -409,10 +409,25 @@ uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
   return typeId;
 }
 
-uint32_t ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes) {
+uint32_t
+ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes,
+                             llvm::StringRef structName,
+                             llvm::ArrayRef<llvm::StringRef> fieldNames) {
   const Type *type = Type::getStruct(theContext, fieldTypes);
-  const uint32_t typeId = theContext.getResultIdForType(type);
+  bool isRegistered = false;
+  const uint32_t typeId = theContext.getResultIdForType(type, &isRegistered);
   theModule.addType(type, typeId);
+  // TODO: Probably we should check duplication and do nothing if trying to add
+  // the same debug name for the same entity in addDebugName().
+  if (!isRegistered) {
+    theModule.addDebugName(typeId, structName);
+    if (!fieldNames.empty()) {
+      assert(fieldNames.size() == fieldTypes.size());
+      for (uint32_t i = 0; i < fieldNames.size(); ++i)
+        theModule.addDebugName(typeId, fieldNames[i],
+                               llvm::Optional<uint32_t>(i));
+    }
+  }
   return typeId;
 }
 

+ 5 - 1
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -15,7 +15,7 @@
 namespace clang {
 namespace spirv {
 
-uint32_t SPIRVContext::getResultIdForType(const Type *t) {
+uint32_t SPIRVContext::getResultIdForType(const Type *t, bool *isRegistered) {
   assert(t != nullptr);
   uint32_t result_id = 0;
 
@@ -24,8 +24,12 @@ uint32_t SPIRVContext::getResultIdForType(const Type *t) {
     // The Type has not been defined yet. Reserve an ID for it.
     result_id = takeNextId();
     typeResultIdMap[t] = result_id;
+    if (isRegistered)
+      *isRegistered = false;
   } else {
     result_id = iter->second;
+    if (isRegistered)
+      *isRegistered = true;
   }
 
   assert(result_id != 0);

+ 18 - 15
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -242,21 +242,7 @@ uint32_t SPIRVEmitter::doExpr(const Expr *expr) {
   }
 
   if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
-    const uint32_t base = doExpr(memberExpr->getBase());
-    const auto *memberDecl = memberExpr->getMemberDecl();
-    if (const auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
-      const auto index =
-          theBuilder.getConstantInt32(fieldDecl->getFieldIndex());
-      const uint32_t fieldType =
-          typeTranslator.translateType(fieldDecl->getType());
-      const uint32_t ptrType = theBuilder.getPointerType(
-          fieldType, declIdMapper.resolveStorageClass(memberExpr->getBase()));
-      return theBuilder.createAccessChain(ptrType, base, {index});
-    } else {
-      emitError("Decl '%0' in MemberExpr is not supported yet.")
-          << memberDecl->getDeclKindName();
-      return 0;
-    }
+    return doMemberExpr(memberExpr);
   }
 
   if (const auto *castExpr = dyn_cast<CastExpr>(expr)) {
@@ -1400,6 +1386,23 @@ uint32_t SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
   return InitListHandler(*this).process(expr);
 }
 
+uint32_t SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
+  const uint32_t base = doExpr(expr->getBase());
+  const auto *memberDecl = expr->getMemberDecl();
+  if (const auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
+    const auto index = theBuilder.getConstantInt32(fieldDecl->getFieldIndex());
+    const uint32_t fieldType =
+        typeTranslator.translateType(fieldDecl->getType());
+    const uint32_t ptrType = theBuilder.getPointerType(
+        fieldType, declIdMapper.resolveStorageClass(expr->getBase()));
+    return theBuilder.createAccessChain(ptrType, base, {index});
+  }
+
+  emitError("Decl '%0' in MemberExpr is not supported yet.")
+      << memberDecl->getDeclKindName();
+  return 0;
+}
+
 uint32_t SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
   const auto opcode = expr->getOpcode();
   const auto *subExpr = expr->getSubExpr();

+ 1 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -92,6 +92,7 @@ private:
   uint32_t doExtMatrixElementExpr(const ExtMatrixElementExpr *expr);
   uint32_t doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr);
   uint32_t doInitListExpr(const InitListExpr *expr);
+  uint32_t doMemberExpr(const MemberExpr *expr);
   uint32_t doUnaryOperator(const UnaryOperator *expr);
 
 private:

+ 5 - 3
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -100,13 +100,15 @@ uint32_t TypeTranslator::translateType(QualType type) {
   if (const auto *structType = dyn_cast<RecordType>(typePtr)) {
     const auto *decl = structType->getDecl();
 
-    // Collect all fields' types.
-    std::vector<uint32_t> fieldTypes;
+    // Collect all fields' types and names.
+    llvm::SmallVector<uint32_t, 4> fieldTypes;
+    llvm::SmallVector<llvm::StringRef, 4> fieldNames;
     for (const auto *field : decl->fields()) {
       fieldTypes.push_back(translateType(field->getType()));
+      fieldNames.push_back(field->getName());
     }
 
-    return theBuilder.getStructType(fieldTypes);
+    return theBuilder.getStructType(fieldTypes, type.getAsString(), fieldNames);
   }
 
   emitError("Type '%0' is not supported yet.") << type->getTypeClassName();

+ 6 - 3
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -24,6 +24,9 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // OpEntryPoint Vertex %VSmain "VSmain" %gl_Position %5 %7 %8
 // OpName %VSmain "VSmain"
 // OpName %bb_entry "bb.entry"
+// OpName %PSInput "PSInput"
+// OpMemberName %PSInput 0 "position"
+// OpMemberName %PSInput 1 "color"
 // OpName %result "result"
 // OpDecorate %gl_Position BuiltIn Position
 // OpDecorate %5 Location 0
@@ -36,8 +39,8 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // %_ptr_Input_v4float = OpTypePointer Input %v4float
 // %void = OpTypeVoid
 // %10 = OpTypeFunction %void
-// %_struct_13 = OpTypeStruct %v4float %v4float
-// %_ptr_Function__struct_13 = OpTypePointer Function %_struct_13
+// %PSInput = OpTypeStruct %v4float %v4float
+// %_ptr_Function_PSInput = OpTypePointer Function %PSInput
 // %_ptr_Function_v4float = OpTypePointer Function %v4float
 // %int_0 = OpConstant %int 0
 // %int_1 = OpConstant %int 1
@@ -47,7 +50,7 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // %8 = OpVariable %_ptr_Input_v4float Input
 // %VSmain = OpFunction %void None %10
 // %bb_entry = OpLabel
-// %result = OpVariable %_ptr_Function__struct_13 Function
+// %result = OpVariable %_ptr_Function_PSInput Function
 // %16 = OpLoad %v4float %7
 // %20 = OpAccessChain %_ptr_Function_v4float %result %int_0
 // OpStore %20 %16

+ 30 - 0
tools/clang/test/CodeGenSPIRV/type.struct.hlsl

@@ -0,0 +1,30 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK:      OpName %S "S"
+// CHECK-NEXT: OpMemberName %S 0 "a"
+// CHECK-NEXT: OpMemberName %S 1 "b"
+// CHECK-NEXT: OpMemberName %S 2 "c"
+
+// CHECK:      OpName %T "T"
+// CHECK-NEXT: OpMemberName %T 0 "x"
+// CHECK-NEXT: OpMemberName %T 1 "y"
+// CHECK-NEXT: OpMemberName %T 2 "z"
+
+// CHECK:      %S = OpTypeStruct %uint %v4float %mat2v3float
+struct S {
+    uint a;
+    float4 b;
+    float2x3 c;
+};
+
+// CHECK:      %T = OpTypeStruct %S %v3int %S
+struct T {
+    S x;
+    int3 y;
+    S z;
+};
+
+void main() {
+    S s;
+    T t;
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -38,6 +38,7 @@ TEST_F(WholeFileTest, ConstantPixelShader) {
 TEST_F(FileTest, ScalarTypes) { runFileTest("type.scalar.hlsl"); }
 TEST_F(FileTest, VectorTypes) { runFileTest("type.vector.hlsl"); }
 TEST_F(FileTest, MatrixTypes) { runFileTest("type.matrix.hlsl"); }
+TEST_F(FileTest, StructTypes) { runFileTest("type.struct.hlsl"); }
 TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
 
 // For constants