Explorar o código

Fix payload/attr/param sizes, add test, fix some metadata non-determinism

Tex Riddell %!s(int64=7) %!d(string=hai) anos
pai
achega
fa03d3362a

+ 29 - 6
lib/HLSL/DxilModule.cpp

@@ -1282,17 +1282,40 @@ void DxilModule::EmitDxilMetadata() {
   if (m_pSM->IsLib()) {
     NamedMDNode *fnProps = m_pModule->getOrInsertNamedMetadata(
         DxilMDHelper::kDxilFunctionPropertiesMDName);
-    for (auto &&pair : m_DxilFunctionPropsMap) {
-      const hlsl::DxilFunctionProps *props = pair.second.get();
-      MDTuple *pProps = m_pMDHelper->EmitDxilFunctionProps(props, pair.first);
+
+    // Sort functions by name to keep metadata deterministic
+    vector<Function *> funcOrder;
+    funcOrder.reserve(std::max(m_DxilFunctionPropsMap.size(),
+                               m_DxilEntrySignatureMap.size()));
+
+    std::transform( m_DxilFunctionPropsMap.begin(),
+                    m_DxilFunctionPropsMap.end(),
+                    std::back_inserter(funcOrder),
+                    [](auto &p) -> Function* { return p.first; } );
+    std::sort(funcOrder.begin(), funcOrder.end(), [](Function *F1, Function *F2) {
+      return F1->getName() < F2->getName();
+    });
+
+    for (auto F : funcOrder) {
+      MDTuple *pProps = m_pMDHelper->EmitDxilFunctionProps(&GetDxilFunctionProps(F), F);
       fnProps->addOperand(pProps);
     }
+    funcOrder.clear();
 
     NamedMDNode *entrySigs = m_pModule->getOrInsertNamedMetadata(
         DxilMDHelper::kDxilEntrySignaturesMDName);
-    for (auto &&pair : m_DxilEntrySignatureMap) {
-      Function *F = pair.first;
-      DxilEntrySignature *Sig = pair.second.get();
+
+    // Sort functions by name to keep metadata deterministic
+    std::transform( m_DxilEntrySignatureMap.begin(),
+                    m_DxilEntrySignatureMap.end(),
+                    std::back_inserter(funcOrder),
+                    [](auto &p) -> Function* { return p.first; } );
+    std::sort(funcOrder.begin(), funcOrder.end(), [](Function *F1, Function *F2) {
+      return F1->getName() < F2->getName();
+    });
+
+    for (auto F : funcOrder) {
+      DxilEntrySignature *Sig = &GetDxilEntrySignature(F);
       MDTuple *pSig = m_pMDHelper->EmitDxilSignatures(*Sig);
       entrySigs->addOperand(
           MDTuple::get(m_Ctx, {ValueAsMetadata::get(F), pSig}));

+ 3 - 3
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -1737,7 +1737,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
               "payload and attribute structures must be user defined types with only numeric contents."));
           } else {
             DataLayout DL(&this->TheModule);
-            unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo));
+            unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo)->getPointerElementType());
             if (0 == ArgNo)
               funcProps->ShaderProps.Ray.payloadSizeInBytes = size;
             else
@@ -1762,7 +1762,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
               "ray payload parameter must be a user defined type with only numeric contents."));
           } else {
             DataLayout DL(&this->TheModule);
-            unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo));
+            unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo)->getPointerElementType());
             funcProps->ShaderProps.Ray.payloadSizeInBytes = size;
           }
         }
@@ -1784,7 +1784,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
               "callable parameter must be a user defined type with only numeric contents."));
           } else {
             DataLayout DL(&this->TheModule);
-            unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo));
+            unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo)->getPointerElementType());
             funcProps->ShaderProps.Ray.paramSizeInBytes = size;
           }
         }

+ 97 - 0
tools/clang/test/CodeGenHLSL/quick-test/raytracing_udt_sizes.hlsl

@@ -0,0 +1,97 @@
+// RUN: %dxc -T lib_6_3 -enable-16bit-types %s | FileCheck %s
+
+///////////////////////////////////////
+// CHECK: !{void (%struct.Payload_20*, %struct.BuiltInTriangleIntersectionAttributes*)* @"\01?anyhit1@@YAXUPayload_20@@UBuiltInTriangleIntersectionAttributes@@@Z", i32 9, i32 20, i32 8}
+
+struct Payload_20 {
+  float3 color;
+  uint2 pos;
+  // align 4
+};
+
+[shader("anyhit")]
+void anyhit1( inout Payload_20 payload,
+                  in BuiltInTriangleIntersectionAttributes attr )
+{
+}
+
+///////////////////////////////////////
+// CHECK: !{void (%struct.Params_16*)* @"\01?callable4@@YAXUParams_16@@@Z", i32 12, i32 16}
+
+struct Params_16 {
+  int64_t i;
+  int16_t i16;
+  // align 8
+};
+
+[shader("callable")]
+void callable4( inout Params_16 params )
+{
+}
+
+///////////////////////////////////////
+// CHECK: !{void (%struct.Payload_16*, %struct.Attributes_12*)* @"\01?closesthit2@@YAXUPayload_16@@UAttributes_12@@@Z", i32 10, i32 16, i32 12}
+
+struct Payload_16 {
+  half a;
+  half2 h2;
+  half b;
+  half c;
+  // pad 2 bytes
+  uint u;
+  // align 4
+};
+
+struct Attributes_12 {
+  half a;
+  // pad 2 bytes
+  bool b;   // 4 bytes for bool
+  uint16_t c;
+  // align 4
+};
+
+[shader("closesthit")]
+void closesthit2( inout Payload_16 payload,
+                  in Attributes_12 attr )
+{
+}
+
+///////////////////////////////////////
+// CHECK: !{void (%struct.Payload_10*, %struct.Attributes_40*)* @"\01?closesthit3@@YAXUPayload_10@@UAttributes_40@@@Z", i32 10, i32 10, i32 40}
+
+struct Payload_10 {
+  half4 color;
+  int16_t i16;
+  // align 2
+};
+
+struct Attributes_40 {
+  half a;
+  // pad 6 bytes
+  double d;
+  int16_t2 w2;
+  // pad 4 bytes
+  int64_t i;
+  half h;
+  // align 8
+};
+
+[shader("closesthit")]
+void closesthit3( inout Payload_10 payload,
+                  in Attributes_40 attr )
+{
+}
+
+///////////////////////////////////////
+// CHECK: !{void (%struct.Payload_8*)* @"\01?miss4@@YAXUPayload_8@@@Z", i32 11, i32 8}
+
+struct Payload_8 {
+  half color;
+  int16_t3 i16;
+  // align 2
+};
+
+[shader("miss")]
+void miss4( inout Payload_8 payload )
+{
+}