2
0
Эх сурвалжийг харах

Disable NRVO for entry functions and patch constant functions (#4466)

Disable NRVO for entry functions and patch constant functions.

For entry functions (and patch constant functions) each write to an output argument creates a new call to dx.storeOutput (or dx.storePatchConstant). With RVO enabled, every write to the return variable becomes an output instruction, which could be excessive. To avoid this, disable NRVO for any entry functions and patch constant functions.
Adam Yang 3 жил өмнө
parent
commit
24909be7d4

+ 1 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -446,6 +446,7 @@ bool IsUserDefinedRecordType(clang::QualType type);
 bool DoesTypeDefineOverloadedOperator(clang::QualType typeWithOperator,
                                       clang::OverloadedOperatorKind opc,
                                       clang::QualType paramType);
+bool IsPatchConstantFunctionDecl(const clang::FunctionDecl *FD);
 
 /// <summary>Adds a function declaration to the specified class record.</summary>
 /// <param name="context">ASTContext that owns declarations.</param>

+ 2 - 0
tools/clang/include/clang/Sema/SemaHLSL.h

@@ -103,6 +103,8 @@ clang::OverloadingResult GetBestViableFunction(
   clang::OverloadCandidateSet& set,
   clang::OverloadCandidateSet::iterator& Best);
 
+bool ShouldSkipNRVO(clang::Sema &sema, clang::QualType returnType, clang::VarDecl *VD, clang::FunctionDecl *FD);
+
 /// <summary>Processes an attribute for a declaration.</summary>
 /// <param name="S">Sema with context.</param>
 /// <param name="D">Annotated declaration.</param>

+ 1 - 48
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -1205,53 +1205,6 @@ UnusualAnnotation* hlsl::UnusualAnnotation::CopyToASTContext(ASTContext& Context
   return (UnusualAnnotation*)result;
 }
 
-static bool HasTessFactorSemantic(const ValueDecl *decl) {
-  for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
-    if (it->getKind() == UnusualAnnotation::UA_SemanticDecl) {
-      const SemanticDecl *sd = cast<SemanticDecl>(it);
-      const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
-      if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
-        return true;
-    }
-  }
-  return false;
-}
-
-static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
-  if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
-    return false;
-
-  if (const RecordType *RT = Ty->getAsStructureType()) {
-    RecordDecl *RD = RT->getDecl();
-    for (FieldDecl *fieldDecl : RD->fields()) {
-      if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
-        return true;
-    }
-    return false;
-  }
-
-  if (Ty->getAsArrayTypeUnsafe())
-    return HasTessFactorSemantic(decl);
-
-  return false;
-}
-
 bool ASTContext::IsPatchConstantFunctionDecl(const FunctionDecl *FD) const {
-  // This checks whether the function is structurally capable of being a patch
-  // constant function, not whether it is in fact the patch constant function
-  // for the entry point of a compiled hull shader (which may not have been
-  // seen yet). So the answer is conservative.
-  if (!FD->getReturnType()->isVoidType()) {
-    // Try to find TessFactor in return type.
-    if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
-      return true;
-  }
-  // Try to find TessFactor in out param.
-  for (const ParmVarDecl *param : FD->params()) {
-    if (param->hasAttr<HLSLOutAttr>()) {
-      if (HasTessFactorSemanticRecurse(param, param->getType()))
-        return true;
-    }
-  }
-  return false;
+  return hlsl::IsPatchConstantFunctionDecl(FD);
 }

+ 55 - 0
tools/clang/lib/AST/HlslTypes.cpp

@@ -14,6 +14,7 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 #include "dxc/Support/Global.h"
+#include "dxc/DXIL/DxilSemantic.h"
 #include "clang/AST/CanonicalType.h"
 #include "clang/AST/DeclTemplate.h"
 #include "clang/AST/HlslTypes.h"
@@ -600,6 +601,60 @@ bool IsUserDefinedRecordType(clang::QualType type) {
   return false;
 }
 
+static bool HasTessFactorSemantic(const ValueDecl *decl) {
+  for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
+    if (it->getKind() == UnusualAnnotation::UA_SemanticDecl) {
+      const SemanticDecl *sd = cast<SemanticDecl>(it);
+      StringRef semanticName;
+      unsigned int index = 0;
+      Semantic::DecomposeNameAndIndex(sd->SemanticName, &semanticName, &index);
+      const hlsl::Semantic *pSemantic = hlsl::Semantic::GetByName(semanticName);
+      if (pSemantic && pSemantic->GetKind() == hlsl::Semantic::Kind::TessFactor)
+        return true;
+    }
+  }
+  return false;
+}
+
+static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
+  if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
+    return false;
+
+  if (const RecordType *RT = Ty->getAsStructureType()) {
+    RecordDecl *RD = RT->getDecl();
+    for (FieldDecl *fieldDecl : RD->fields()) {
+      if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
+        return true;
+    }
+    return false;
+  }
+
+  if (Ty->getAsArrayTypeUnsafe())
+    return HasTessFactorSemantic(decl);
+
+  return false;
+}
+
+bool IsPatchConstantFunctionDecl(const clang::FunctionDecl *FD) {
+  // This checks whether the function is structurally capable of being a patch
+  // constant function, not whether it is in fact the patch constant function
+  // for the entry point of a compiled hull shader (which may not have been
+  // seen yet). So the answer is conservative.
+  if (!FD->getReturnType()->isVoidType()) {
+    // Try to find TessFactor in return type.
+    if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
+      return true;
+  }
+  // Try to find TessFactor in out param.
+  for (const ParmVarDecl *param : FD->params()) {
+    if (param->hasAttr<HLSLOutAttr>()) {
+      if (HasTessFactorSemanticRecurse(param, param->getType()))
+        return true;
+    }
+  }
+  return false;
+}
+
 bool DoesTypeDefineOverloadedOperator(clang::QualType typeWithOperator,
                                       clang::OverloadedOperatorKind opc,
                                       clang::QualType paramType) {

+ 58 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -11069,6 +11069,64 @@ unsigned hlsl::CaculateInitListArraySizeForHLSL(
   }
 }
 
+// NRVO unsafe for a variety of cases in HLSL
+// - vectors/matrix with bool component types
+// - attributes not captured to QualType, such as precise and globallycoherent
+bool hlsl::ShouldSkipNRVO(clang::Sema& sema, clang::QualType returnType, clang::VarDecl *VD, clang::FunctionDecl *FD) {
+  // exclude vectors/matrix (not treated as record type)
+  // NRVO breaks on bool component type due to diff between
+  // i32 memory and i1 register representation
+  if (hlsl::IsHLSLVecMatType(returnType))
+    return true;
+  QualType ArrayEltTy = returnType;
+  while (const clang::ArrayType *AT =
+             sema.getASTContext().getAsArrayType(ArrayEltTy)) {
+    ArrayEltTy = AT->getElementType();
+  }
+  // exclude resource for globallycoherent.
+  if (hlsl::IsHLSLResourceType(ArrayEltTy))
+    return true;
+  // exclude precise.
+  if (VD->hasAttr<HLSLPreciseAttr>()) {
+    return true;
+  }
+  if (FD) {
+    // propagate precise the the VD.
+    if (FD->hasAttr<HLSLPreciseAttr>()) {
+      VD->addAttr(FD->getAttr<HLSLPreciseAttr>());
+      return true;
+    }
+
+    // Don't do NRVO if this is an entry function or a patch contsant function.
+    // With NVRO, writing to the return variable directly writes to the output
+    // argument instead of to an alloca which gets copied to the output arg in one
+    // spot. This causes many extra dx.storeOutput's to be emitted.
+    //
+    // Check if this is an entry function the easy way if we're a library
+    if (const HLSLShaderAttr *Attr = FD->getAttr<HLSLShaderAttr>()) {
+      return true;
+    }
+    // Check if it's an entry function the hard way
+    if (!FD->getDeclContext()->isNamespace() && FD->isGlobal()) {
+      // Check if this is an entry function by comparing name
+      if (FD->getName() == sema.getLangOpts().HLSLEntryFunction) {
+        return true;
+      }
+
+      // See if it's the patch constant function
+      if (sema.getLangOpts().HLSLProfile.size() &&
+        (sema.getLangOpts().HLSLProfile[0] == 'h' /*For 'hs'*/ ||
+         sema.getLangOpts().HLSLProfile[0] == 'l' /*For 'lib'*/))
+      {
+        if (hlsl::IsPatchConstantFunctionDecl(FD))
+          return true;
+      }
+    }
+  }
+
+  return false;
+}
+
 bool hlsl::IsConversionToLessOrEqualElements(
   _In_ clang::Sema* self,
   const clang::ExprResult& sourceExpr,

+ 2 - 28
tools/clang/lib/Sema/SemaStmt.cpp

@@ -2702,34 +2702,8 @@ VarDecl *Sema::getCopyElisionCandidate(QualType ReturnType,
     return nullptr;
 
   // HLSL Change Begins: NRVO unsafe for a variety of cases in HLSL
-  // - vectors/matrix with bool component types
-  // - attributes not captured to QualType, such as precise and globallycoherent
-  if (getLangOpts().HLSL) {
-    // exclude vectors/matrix (not treated as record type)
-    // NRVO breaks on bool component type due to diff between
-    // i32 memory and i1 register representation
-    if (hlsl::IsHLSLVecMatType(ReturnType))
-      return nullptr;
-    QualType ArrayEltTy = ReturnType;
-    while (const clang::ArrayType *AT =
-               Context.getAsArrayType(ArrayEltTy)) {
-      ArrayEltTy = AT->getElementType();
-    }
-    // exclude resource for globallycoherent.
-    if (hlsl::IsHLSLResourceType(ArrayEltTy))
-      return nullptr;
-    // exclude precise.
-    if (VD->hasAttr<HLSLPreciseAttr>()) {
-      return nullptr;
-    }
-    // propagate precise the the VD.
-    if (const FunctionDecl *FD = getCurFunctionDecl()) {
-      if (FD->hasAttr<HLSLPreciseAttr>()) {
-        VD->addAttr(FD->getAttr<HLSLPreciseAttr>());
-        return nullptr;
-      }
-    }
-  }
+  if (getLangOpts().HLSL && hlsl::ShouldSkipNRVO(*this, ReturnType, VD, getCurFunctionDecl()))
+    return nullptr;
   // HLSL Change Ends
 
   if (isCopyElisionCandidate(ReturnType, VD, AllowFunctionParameter))

+ 190 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/return/EntryFunctionDisableNVRO.hlsl

@@ -0,0 +1,190 @@
+// RUN: %dxc /D PS /Tps_6_0 %s | FileCheck %s
+// RUN: %dxc /D PS /Tlib_6_3 %s | FileCheck %s
+// RUN: %dxc /D PS /Tps_6_0 -fcgl %s | FileCheck %s -check-prefix=FCGL
+// RUN: %dxc /D VS /Tvs_6_0 %s | FileCheck %s
+// RUN: %dxc /D VS /Tlib_6_3 %s | FileCheck %s
+// RUN: %dxc /D VS /Tvs_6_0 -fcgl %s | FileCheck %s -check-prefix=FCGL
+// RUN: %dxc /D DS /Tds_6_0 %s | FileCheck %s
+// RUN: %dxc /D DS /Tlib_6_3 %s | FileCheck %s
+// RUN: %dxc /D DS /Tds_6_0 -fcgl %s | FileCheck %s -check-prefix=FCGL
+// RUN: %dxc /D HS /Ths_6_0 %s | FileCheck %s -check-prefix=HS
+// RUN: %dxc /D HS /Tlib_6_3 %s | FileCheck %s -check-prefix=HS
+// RUN: %dxc /D HS /Ths_6_0 -fcgl %s | FileCheck %s -check-prefix=HS_FCGL
+
+// This test is for making sure we don't do NRVO for entry functions and patch constant functions.
+// For the normal compile version, we check there's the right number of storeOutput calls.
+// For the fcgl version, we check there is a return value alloca emitted.
+
+// CHECK: call void @dx.op.storeOutput
+// CHECK: call void @dx.op.storeOutput
+// CHECK: call void @dx.op.storeOutput
+// CHECK: call void @dx.op.storeOutput
+// CHECK-NOT: call void @dx.op.storeOutput
+
+// FCGL: alloca %struct.MyReturn
+
+cbuffer cb : register(b0) {
+    float4 foo;
+    bool a, b, c, d, e, f, g;
+}
+
+#ifdef PS
+struct MyReturn {
+    float4 member : SV_Target;
+};
+
+[shader("pixel")]
+MyReturn main() {
+    MyReturn ret = (MyReturn)0;
+    if (a) {
+        ret.member = foo;
+    }
+    else if (b) {
+        ret.member = foo * 2;
+        if (c) {
+            ret.member = foo * 4;
+        }
+        else if (d) {
+            ret.member = foo * 8;
+        }
+    }
+    return ret;
+}
+
+#elif defined(VS)
+
+struct MyReturn {
+    float4 member : SV_Position;
+};
+
+[shader("vertex")]
+MyReturn main() {
+    MyReturn ret = (MyReturn)0;
+    if (a) {
+        ret.member = foo;
+    }
+    else if (b) {
+        ret.member = foo * 2;
+        if (c) {
+            ret.member = foo * 4;
+        }
+        else if (d) {
+            ret.member = foo * 8;
+        }
+    }
+    return ret;
+}
+
+#elif defined(DS)
+
+struct MyReturn {
+    float4 member : SV_Position;
+};
+
+[domain("tri")]
+[shader("domain")]
+MyReturn main() {
+    MyReturn ret = (MyReturn)0;
+    if (a) {
+        ret.member = foo;
+    }
+    else if (b) {
+        ret.member = foo * 2;
+        if (c) {
+            ret.member = foo * 4;
+        }
+        else if (d) {
+            ret.member = foo * 8;
+        }
+    }
+    return ret;
+}
+
+#elif defined(HS)
+
+// HS-LABEL: @"\01?patch_const@@YA?AUMyPatchConstantReturn@@XZ"
+// HS: call void @dx.op.storePatchConstant
+// HS: call void @dx.op.storePatchConstant
+// HS: call void @dx.op.storePatchConstant
+// HS: call void @dx.op.storePatchConstant
+// HS-NOT: call void @dx.op.storePatchConstant
+
+// HS-LABEL: @main
+// HS: call void @dx.op.storeOutput
+// HS: call void @dx.op.storeOutput
+// HS: call void @dx.op.storeOutput
+// HS: call void @dx.op.storeOutput
+// HS-NOT: call void @dx.op.storeOutput
+
+// HS_FCGL-DAG: alloca %struct.MyReturn
+// HS_FCGL-DAG: alloca %struct.MyPatchConstantReturn
+
+
+struct MyPatchConstantReturn {
+    float member[3] : SV_TessFactor0; // <-- with 0;
+    float member2 : SV_InsideTessFactor;
+};
+MyPatchConstantReturn patch_const() {
+    MyPatchConstantReturn ret = (MyPatchConstantReturn)0;
+    if (a) {
+        ret.member[0] = foo[0];
+        ret.member[1] = foo[1];
+        ret.member[2] = foo[2];
+        ret.member2 = foo[3];
+    }
+    else if (b) {
+        ret.member[0] = foo[0] * 2;
+        ret.member[1] = foo[1] * 2;
+        ret.member[2] = foo[2] * 2;
+        ret.member2    = foo[3] * 2;
+        if (c) {
+            ret.member[0] = foo[0] * 4;
+            ret.member[1] = foo[1] * 4;
+            ret.member[2] = foo[2] * 4;
+            ret.member2 = foo[3] * 4;
+        }
+        else if (d) {
+            ret.member[0] = foo[0] * 8;
+            ret.member[1] = foo[1] * 8;
+            ret.member[2] = foo[2] * 8;
+            ret.member2 = foo[3] * 8;
+        }
+    }
+    return ret;
+}
+
+struct MyReturn {
+    float4 member : SV_Position;
+};
+
+struct MyInput {
+    float4 member : POSITION;
+};
+
+[domain("tri")]
+[outputtopology("triangle_cw")]
+[patchconstantfunc("patch_const")]
+[partitioning("fractional_odd")]
+[outputcontrolpoints(3)]
+[maxtessfactor(9)]
+[shader("hull")]
+MyReturn main( const uint id : SV_OutputControlPointID,
+               const InputPatch< MyInput, 3 > points )
+{
+    MyReturn ret = (MyReturn)0;
+    if (a) {
+        ret.member = points[id].member;
+    }
+    else if (b) {
+        ret.member = points[id].member;
+        if (c) {
+            ret.member = points[id].member;
+        }
+        else if (d) {
+            ret.member = points[id].member;
+        }
+    }
+    return ret;
+}
+
+#endif