2
0
Эх сурвалжийг харах

Update helpers for working with precise (#333)

This commit updates the `IsPrecise` helper to also check for the fast math
flags on an instruction. Now, the function will return the correct value
for any instruction. Before it only worked for intrinsic calls.

We also added three helpers for working with fast math flags.
We provide getters and setters for precise fast math flags and
a helper to say if fast math flags are preserved through serialization
and deserialization.

Finally, we added a unit test for DxilModule that we can use to ensure
that the `IsPrecise` helper returns the correct value.
David Peixotto 8 жил өмнө
parent
commit
2afe446dae

+ 21 - 3
include/dxc/HLSL/DxilModule.h

@@ -150,9 +150,27 @@ public:
 
   static DxilModule *TryGetDxilModule(llvm::Module *pModule);
 
-  // Return true if the instruction is marked precise or if global
-  // refactoring is disabled.
-  bool IsPrecise(llvm::Instruction *inst);
+  // Helpers for working with precise.
+
+  // Return true if the instruction should be considered precise.
+  //
+  // An instruction can be marked precise in the following ways:
+  //
+  // 1. Global refactoring is disabled.
+  // 2. The instruction has a precise metadata annotation.
+  // 3. The instruction has precise fast math flags set.
+  //
+  bool IsPrecise(const llvm::Instruction *inst) const;
+
+  // Check if the instruction has fast math flags configured to indicate
+  // the instruction is precise.
+  static bool HasPreciseFastMathFlags(const llvm::Instruction *inst);
+  
+  // Set fast math flags configured to indicate the instruction is precise.
+  static void SetPreciseFastMathFlags(llvm::Instruction *inst);
+  
+  // True if fast math flags are preserved across serialize/deserialize.
+  static bool PreservesFastMathFlags(const llvm::Instruction *inst);
 
 public:
   // Shader properties.

+ 5 - 3
lib/HLSL/DxilGenerationPass.cpp

@@ -3085,12 +3085,14 @@ static void PropagatePreciseAttributeOnOperand(Value *V, DxilTypeSystem &typeSys
     return;
 
   // Skip inst already marked.
-  if (!I->hasUnsafeAlgebra())
+  if (DxilModule::HasPreciseFastMathFlags(I))
     return;
   // TODO: skip precise on integer type, sample instruction...
 
-  // Clear fast math.
-  I->copyFastMathFlags(FastMathFlags());
+  // Set precise fast math on those instructions that support it.
+  if (DxilModule::PreservesFastMathFlags(I))
+    DxilModule::SetPreciseFastMathFlags(I);
+
   // Fast math not work on call, use metadata.
   if (CallInst *CI = dyn_cast<CallInst>(I))
     HLModule::MarkPreciseAttributeWithMetadata(CI);

+ 37 - 2
lib/HLSL/DxilModule.cpp

@@ -21,6 +21,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
@@ -1346,10 +1347,44 @@ hlsl::DxilModule *hlsl::DxilModule::TryGetDxilModule(llvm::Module *pModule) {
   return pDxilModule;
 }
 
-bool DxilModule::IsPrecise(Instruction *inst) {
+// Check if the instruction has fast math flags configured to indicate
+// the instruction is precise.
+// Precise fast math flags means none of the fast math flags are set.
+bool DxilModule::HasPreciseFastMathFlags(const Instruction *inst) {
+  return isa<FPMathOperator>(inst) && !inst->getFastMathFlags().any();
+}
+
+// Set fast math flags configured to indicate the instruction is precise.
+void DxilModule::SetPreciseFastMathFlags(llvm::Instruction *inst) {
+  assert(isa<FPMathOperator>(inst));
+  inst->copyFastMathFlags(FastMathFlags());
+}
+
+// True if fast math flags are preserved across serialization/deserialization
+// of the dxil module.
+//
+// We need to check for this when querying fast math flags for preciseness
+// otherwise we will be overly conservative by reporting instructions precise
+// because their fast math flags were not preserved.
+//
+// Currently we restrict it to the instruction types that have fast math
+// preserved in the bitcode. We can expand this by converting fast math
+// flags to dx.precise metadata during serialization and back to fast
+// math flags during deserialization.
+bool DxilModule::PreservesFastMathFlags(const llvm::Instruction *inst) {
+  return
+    isa<FPMathOperator>(inst) && (isa<BinaryOperator>(inst) || isa<FCmpInst>(inst));
+}
+
+bool DxilModule::IsPrecise(const Instruction *inst) const {
   if (m_ShaderFlags.GetDisableMathRefactoring())
     return true;
-  return DxilMDHelper::IsMarkedPrecise(inst);
+  else if (DxilMDHelper::IsMarkedPrecise(inst))
+    return true;
+  else if (PreservesFastMathFlags(inst))
+    return HasPreciseFastMathFlags(inst);
+  else
+    return false;
 }
 
 } // namespace hlsl

+ 3 - 0
tools/clang/unittests/HLSL/CMakeLists.txt

@@ -10,9 +10,11 @@ set( LLVM_LINK_COMPONENTS
   dxcsupport
   hlsl
   option
+  bitreader
   bitwriter
   analysis
   ipa
+  irreader
   )
 
 add_clang_library(clang-hlsl-tests SHARED
@@ -20,6 +22,7 @@ add_clang_library(clang-hlsl-tests SHARED
   CompilationResult.h
   CompilerTest.cpp
   DxilContainerTest.cpp
+  DxilModuleTest.cpp
   DXIsenseTest.cpp
   ExecutionTest.cpp
   ExtensionTest.cpp

+ 389 - 0
tools/clang/unittests/HLSL/DxilModuleTest.cpp

@@ -0,0 +1,389 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// DxilModuleTest.cpp                                                        //
+//                                                                           //
+// Provides unit tests for DxilModule.                                       //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "CompilationResult.h"
+#include "WexTestClass.h"
+#include "HlslTestUtils.h"
+#include "DxcTestUtils.h"
+#include "dxc/Support/microcom.h"
+#include "dxc/dxcapi.internal.h"
+#include "dxc/HLSL/HLOperationLowerExtension.h"
+#include "dxc/HlslIntrinsicOp.h"
+#include "dxc/HLSL/DxilOperations.h"
+#include "dxc/HLSL/DxilInstructions.h"
+#include "dxc/HLSL/DxilContainer.h"
+#include "dxc/HLSL/DxilModule.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/MSFileSystem.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/BitCode/ReaderWriter.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/InstIterator.h"
+
+using namespace hlsl;
+using namespace llvm;
+
+///////////////////////////////////////////////////////////////////////////////
+// DxilModule unit tests.
+
+class DxilModuleTest {
+public:
+  BEGIN_TEST_CLASS(DxilModuleTest)
+    TEST_CLASS_PROPERTY(L"Parallel", L"true")
+    TEST_METHOD_PROPERTY(L"Priority", L"0")
+  END_TEST_CLASS()
+
+  dxc::DxcDllSupport m_dllSupport;
+
+  // Basic loading tests.
+  TEST_METHOD(LoadDxilModule_1_0);
+  TEST_METHOD(LoadDxilModule_1_1);
+  
+  // Precise query tests.
+  TEST_METHOD(Precise1);
+  TEST_METHOD(Precise2);
+  TEST_METHOD(Precise3);
+  TEST_METHOD(Precise4);
+  TEST_METHOD(Precise5);
+  TEST_METHOD(Precise6);
+  TEST_METHOD(Precise7);
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Compilation and dxil module loading support.
+
+namespace {
+class Compiler {
+public:
+  Compiler(dxc::DxcDllSupport &dll) 
+    : m_dllSupport(dll) 
+    , m_msf(CreateMSFileSystem())
+    , m_pts(m_msf.get())
+  {
+    VERIFY_SUCCEEDED(m_dllSupport.Initialize());
+    VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+  }
+  
+  IDxcOperationResult *Compile(const char *program, LPCWSTR shaderModel = L"ps_6_0") {
+    return Compile(program, shaderModel, {}, {});
+  }
+
+  IDxcOperationResult *Compile(const char *program, LPCWSTR shaderModel, const std::vector<LPCWSTR> &arguments, const std::vector<DxcDefine> defs ) {
+    Utf8ToBlob(m_dllSupport, program, &pCodeBlob);
+    VERIFY_SUCCEEDED(pCompiler->Compile(pCodeBlob, L"hlsl.hlsl", L"main",
+      shaderModel,
+      const_cast<LPCWSTR *>(arguments.data()), arguments.size(),
+      defs.data(), defs.size(),
+      nullptr, &pCompileResult));
+
+    return pCompileResult;
+  }
+
+  std::string Disassemble() {
+    CComPtr<IDxcBlob> pBlob;
+    CheckOperationSucceeded(pCompileResult, &pBlob);
+    return DisassembleProgram(m_dllSupport, pBlob);
+  }
+
+  DxilModule &GetDxilModule() {
+    // Make sure we compiled successfully.
+    CComPtr<IDxcBlob> pBlob;
+    CheckOperationSucceeded(pCompileResult, &pBlob);
+    
+    // Verify we have a valid dxil container.
+    const DxilContainerHeader *pContainer =
+      IsDxilContainerLike(pBlob->GetBufferPointer(), pBlob->GetBufferSize());
+    VERIFY_IS_NOT_NULL(pContainer);
+    VERIFY_IS_TRUE(IsValidDxilContainer(pContainer, pBlob->GetBufferSize()));
+        
+    // Get Dxil part from container.
+    DxilPartIterator it = std::find_if(begin(pContainer), end(pContainer), DxilPartIsType(DFCC_DXIL));
+    VERIFY_IS_FALSE(it == end(pContainer));
+    
+    const DxilProgramHeader *pProgramHeader =
+        reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(*it));
+    VERIFY_IS_TRUE(IsValidDxilProgramHeader(pProgramHeader, (*it)->PartSize));
+        
+    // Get a pointer to the llvm bitcode.
+    const char *pIL;
+    uint32_t pILLength;
+    GetDxilProgramBitcode(pProgramHeader, &pIL, &pILLength);
+      
+    // Parse llvm bitcode into a module.
+    std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf(
+          llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(pIL, pILLength), "", false));
+    llvm::ErrorOr<std::unique_ptr<llvm::Module>>
+      pModule(llvm::parseBitcodeFile(pBitcodeBuf->getMemBufferRef(), m_llvmContext));
+    if (std::error_code ec = pModule.getError()) {
+      VERIFY_FAIL();
+    }
+    m_module = std::move(pModule.get());
+
+    // Grab the dxil module;
+    DxilModule *DM = DxilModule::TryGetDxilModule(m_module.get());
+    VERIFY_IS_NOT_NULL(DM);
+    return *DM;
+  }
+
+private:
+  static ::llvm::sys::fs::MSFileSystem *CreateMSFileSystem() {
+    ::llvm::sys::fs::MSFileSystem *msfPtr;
+    VERIFY_SUCCEEDED(CreateMSFileSystemForDisk(&msfPtr));
+    return msfPtr;
+  }
+
+  dxc::DxcDllSupport &m_dllSupport;
+  CComPtr<IDxcCompiler> pCompiler;
+  CComPtr<IDxcBlobEncoding> pCodeBlob;
+  CComPtr<IDxcOperationResult> pCompileResult;
+  llvm::LLVMContext m_llvmContext;
+  std::unique_ptr<llvm::Module> m_module;
+  std::unique_ptr<::llvm::sys::fs::MSFileSystem> m_msf;
+  ::llvm::sys::fs::AutoPerThreadSystem m_pts;
+};
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Unit Test Implementation
+TEST_F(DxilModuleTest, LoadDxilModule_1_0) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "float4 main() : SV_Target {\n"
+    "  return 0;\n"
+    "}\n"
+    ,
+    L"ps_6_0"
+  );
+
+  // Basic sanity check on dxil version in dxil module.
+  DxilModule &DM = c.GetDxilModule();
+  unsigned vMajor, vMinor;
+  DM.GetDxilVersion(vMajor, vMinor);
+  VERIFY_IS_TRUE(vMajor == 1);
+  VERIFY_IS_TRUE(vMinor == 0);
+}
+
+TEST_F(DxilModuleTest, LoadDxilModule_1_1) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "float4 main() : SV_Target {\n"
+    "  return 0;\n"
+    "}\n"
+    ,
+    L"ps_6_1"
+  );
+
+  // Basic sanity check on dxil version in dxil module.
+  DxilModule &DM = c.GetDxilModule();
+  unsigned vMajor, vMinor;
+  DM.GetDxilVersion(vMajor, vMinor);
+  VERIFY_IS_TRUE(vMajor == 1);
+  VERIFY_IS_TRUE(vMinor == 1);
+}
+
+TEST_F(DxilModuleTest, Precise1) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "precise float main(float x : X, float y : Y) : SV_Target {\n"
+    "  return sqrt(x) + y;\n"
+    "}\n"
+  );
+
+  // Make sure sqrt and add are marked precise.
+  DxilModule &DM = c.GetDxilModule();
+  Function *F = DM.GetEntryFunction();
+  int numChecks = 0;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+    Instruction *Inst = &*I;
+    if (DxilInst_Sqrt(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FAdd(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+  }
+  VERIFY_ARE_EQUAL(numChecks, 2);
+}
+
+TEST_F(DxilModuleTest, Precise2) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "float main(float x : X, float y : Y) : SV_Target {\n"
+    "  return sqrt(x) + y;\n"
+    "}\n"
+  );
+
+  // Make sure sqrt and add are not marked precise.
+  DxilModule &DM = c.GetDxilModule();
+  Function *F = DM.GetEntryFunction();
+  int numChecks = 0;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+    Instruction *Inst = &*I;
+    if (DxilInst_Sqrt(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FAdd(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+  }
+  VERIFY_ARE_EQUAL(numChecks, 2);
+}
+
+TEST_F(DxilModuleTest, Precise3) {
+  // TODO: Enable this test when precise metadata is inserted for Gis.
+  if (const bool GisIsBroken = true) return;
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "float main(float x : X, float y : Y) : SV_Target {\n"
+    "  return sqrt(x) + y;\n"
+    "}\n",
+    L"ps_6_0",
+    { L"/Gis" }, {}
+  );
+
+  // Make sure sqrt and add are marked precise.
+  DxilModule &DM = c.GetDxilModule();
+  Function *F = DM.GetEntryFunction();
+  int numChecks = 0;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+    Instruction *Inst = &*I;
+    if (DxilInst_Sqrt(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FAdd(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+  }
+  VERIFY_ARE_EQUAL(numChecks, 2);
+}
+
+TEST_F(DxilModuleTest, Precise4) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "float main(float x : X, float y : Y) : SV_Target {\n"
+    "  precise float sx = 1 / sqrt(x);\n"
+    "  return sx + y;\n"
+    "}\n"
+  );
+
+  // Make sure sqrt and div are marked precise, and add is not.
+  DxilModule &DM = c.GetDxilModule();
+  Function *F = DM.GetEntryFunction();
+  int numChecks = 0;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+    Instruction *Inst = &*I;
+    if (DxilInst_Sqrt(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FDiv(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FAdd(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+  }
+  VERIFY_ARE_EQUAL(numChecks, 3);
+}
+
+TEST_F(DxilModuleTest, Precise5) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "float C[10];\n"
+    "float main(float x : X, float y : Y, int i : I) : SV_Target {\n"
+    "  float A[2];\n"
+    "  A[0] = x;\n"
+    "  A[1] = y;\n"
+    "  return A[i] + C[i];\n"
+    "}\n"
+  );
+
+  // Make sure load and extract value are not reported as precise.
+  DxilModule &DM = c.GetDxilModule();
+  Function *F = DM.GetEntryFunction();
+  int numChecks = 0;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+    Instruction *Inst = &*I;
+    if (LlvmInst_ExtractValue(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_Load(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FAdd(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+  }
+  VERIFY_ARE_EQUAL(numChecks, 3);
+}
+
+TEST_F(DxilModuleTest, Precise6) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "precise float2 main(float2 x : A, float2 y : B) : SV_Target {\n"
+    "  return sqrt(x * y);\n"
+    "}\n"
+  );
+
+  // Make sure sqrt and mul are marked precise.
+  DxilModule &DM = c.GetDxilModule();
+  Function *F = DM.GetEntryFunction();
+  int numChecks = 0;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+    Instruction *Inst = &*I;
+    if (DxilInst_Sqrt(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FMul(Inst)) {
+      numChecks++;
+      VERIFY_IS_TRUE(DM.IsPrecise(Inst));
+    }
+  }
+  VERIFY_ARE_EQUAL(numChecks, 4);
+}
+
+TEST_F(DxilModuleTest, Precise7) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "float2 main(float2 x : A, float2 y : B) : SV_Target {\n"
+    "  return sqrt(x * y);\n"
+    "}\n"
+  );
+
+  // Make sure sqrt and mul are not marked precise.
+  DxilModule &DM = c.GetDxilModule();
+  Function *F = DM.GetEntryFunction();
+  int numChecks = 0;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
+    Instruction *Inst = &*I;
+    if (DxilInst_Sqrt(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+    else if (LlvmInst_FMul(Inst)) {
+      numChecks++;
+      VERIFY_IS_FALSE(DM.IsPrecise(Inst));
+    }
+  }
+  VERIFY_ARE_EQUAL(numChecks, 4);
+}