Quellcode durchsuchen

Transform mad(a, 0, b) into b. (#1150)

Xiang Li vor 7 Jahren
Ursprung
Commit
a489c2ec25

+ 43 - 0
include/llvm/Analysis/DxilSimplify.h

@@ -0,0 +1,43 @@
+//===-- DxilSimplify.h - Simplify Dxil operations ------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+// Copyright (C) Microsoft Corporation. All rights reserved.
+//===----------------------------------------------------------------------===//
+//
+// This file declares routines for simplifying dxil intrinsics when some
+// operands are constants.
+//
+// We hook into the llvm::SimplifyInstruction so the function
+// interfaces are dictated by what llvm provides.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_HLSLDXILSIMPLIFY_H
+#define LLVM_ANALYSIS_HLSLDXILSIMPLIFY_H
+#include "llvm/ADT/ArrayRef.h"
+
+namespace llvm {
+class Function;
+class Instruction;
+class Value;
+} // namespace llvm
+
+namespace hlsl {
+/// \brief Given a dxil operation function \p F and its call arguments
+/// \p Args (the opcode is the first argument), try to simplify the call \p I
+/// to an existing value or to a cheaper instruction inserted before \p I.
+/// If this call could not be simplified returns null.
+llvm::Value *SimplifyDxilCall(llvm::Function *F,
+                              llvm::ArrayRef<llvm::Value *> Args,
+                              llvm::Instruction *I);
+
+/// Return true if \p F is a dxil operation function of an opcode class that
+/// SimplifyDxilCall knows how to simplify; return false otherwise.
+bool CanSimplify(const llvm::Function *F);
+} // namespace hlsl
+
+#endif

+ 1 - 0
lib/Analysis/CMakeLists.txt

@@ -27,6 +27,7 @@ add_llvm_library(LLVMAnalysis
   DominanceFrontier.cpp
   DxilConstantFolding.cpp
   DxilConstantFoldingExt.cpp
+  DxilSimplify.cpp
   IVUsers.cpp
   InstCount.cpp
   InstructionSimplify.cpp

+ 170 - 0
lib/Analysis/DxilSimplify.cpp

@@ -0,0 +1,170 @@
+//===-- DxilSimplify.cpp - Fold dxil intrinsics into constants -----===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+// Copyright (C) Microsoft Corporation. All rights reserved.
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+// Simplify dxil ops whose result is known from constant operands,
+// e.g. mad(0, a, b) -> b.
+
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/IRBuilder.h"
+
+#include "dxc/HLSL/DxilModule.h"
+#include "dxc/HLSL/DxilOperations.h"
+#include "llvm/Analysis/DxilConstantFolding.h"
+#include "llvm/Analysis/DxilSimplify.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+namespace {
+/// Decode the dxil opcode carried by \p opArg.
+///
+/// Returns DXIL::OpCode::NumOpCodes when \p opArg is not a constant integer
+/// or when its value is not below NumOpCodes.
+DXIL::OpCode GetOpcode(Value *opArg) {
+  ConstantInt *constOp = dyn_cast<ConstantInt>(opArg);
+  if (!constOp)
+    return DXIL::OpCode::NumOpCodes;
+
+  const uint64_t rawValue = constOp->getLimitedValue();
+  const uint64_t opCount = static_cast<uint64_t>(OP::OpCode::NumOpCodes);
+  if (rawValue >= opCount)
+    return DXIL::OpCode::NumOpCodes;
+
+  return static_cast<OP::OpCode>(rawValue);
+}
+} // namespace
+
+namespace hlsl {
+/// Return true if \p F is a dxil operation function whose opcode class
+/// SimplifyDxilCall has simplifications for (currently only Tertiary).
+///
+/// \param F the callee of a call instruction; may be null, in which case
+///          false is returned.
+bool CanSimplify(const llvm::Function *F) {
+  // llvm::SimplifyInstruction passes CallSite::getCalledFunction() here,
+  // which is null for indirect calls. There is nothing to simplify then,
+  // and dereferencing F below would crash.
+  if (!F)
+    return false;
+
+  // Only simplify dxil functions when we have a valid dxil module.
+  if (!F->getParent()->HasDxilModule()) {
+    assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
+    return false;
+  }
+
+  // Lookup opcode class in dxil module. Set default value to invalid class.
+  OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
+  if (!F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass))
+    return false;
+
+  // Tertiary is the only dxil operation class we can simplify.
+  return opClass == OP::OpCodeClass::Tertiary;
+}
+
+/// Build the add that replaces mad(1, lhs, rhs) or mad(lhs, 1, rhs). The new
+/// instruction is inserted before \p I. Floating point adds are created with
+/// the HLSL unsafe-algebra fast-math flags; integer adds carry no flags.
+static Value *CreateMadAsAdd(llvm::Instruction *I, bool isFloat, Value *lhs,
+                             Value *rhs) {
+  IRBuilder<> Builder(I);
+  if (!isFloat)
+    return Builder.CreateAdd(lhs, rhs);
+
+  llvm::FastMathFlags FMF;
+  FMF.setUnsafeAlgebraHLSL();
+  Builder.SetFastMathFlags(FMF);
+  return Builder.CreateFAdd(lhs, rhs);
+}
+
+/// \brief Given a function and set of arguments, see if we can fold the
+/// result as dxil operation.
+///
+/// Simplifications performed (for FMad, IMad and UMad):
+///   mad(0, x, y) / mad(x, 0, y) -> y
+///   mad(1, x, y) / mad(x, 1, y) -> x + y
+/// Calls whose arguments are all constants are constant folded first.
+///
+/// \param F    the called dxil operation function.
+/// \param Args the call arguments; Args[0] is expected to be the opcode.
+/// \param I    the call instruction; checked for precise and used as the
+///             insertion point for any replacement add.
+///
+/// NOTE(review): the mad(1, x, y) cases insert a new instruction before I,
+/// while llvm::SimplifyInstruction callers generally expect simplification
+/// not to create instructions - confirm this is acceptable for all callers.
+///
+/// If this call could not be simplified returns null.
+Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
+                        llvm::Instruction *I) {
+  if (!F->getParent()->HasDxilModule()) {
+    assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
+    return nullptr;
+  }
+
+  DxilModule &DM = F->getParent()->GetDxilModule();
+  // Skip precise.
+  if (DM.IsPrecise(I))
+    return nullptr;
+
+  // Lookup opcode class in dxil module. Set default value to invalid class.
+  OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
+  if (!DM.GetOP()->GetOpCodeClass(F, opClass))
+    return nullptr;
+
+  // The opcode is read from the first argument; bail out instead of reading
+  // past the end of an empty argument list.
+  if (Args.empty())
+    return nullptr;
+
+  DXIL::OpCode opcode = GetOpcode(Args[0]);
+  if (opcode == DXIL::OpCode::NumOpCodes)
+    return nullptr;
+
+  if (CanConstantFoldCallTo(F)) {
+    SmallVector<Constant *, 4> ConstantArgs;
+    ConstantArgs.reserve(Args.size());
+    for (Value *V : Args) {
+      Constant *C = dyn_cast<Constant>(V);
+      if (!C)
+        break;
+      ConstantArgs.push_back(C);
+    }
+
+    // Every argument was a constant. If the folder cannot handle this call
+    // it returns null; fall through so the algebraic rules below still get
+    // a chance instead of giving up.
+    if (ConstantArgs.size() == Args.size()) {
+      if (Value *Folded = hlsl::ConstantFoldScalarCall(
+              F->getName(), F->getReturnType(), ConstantArgs))
+        return Folded;
+    }
+  }
+
+  switch (opcode) {
+  default:
+    return nullptr;
+  case DXIL::OpCode::FMad:
+  case DXIL::OpCode::IMad:
+  case DXIL::OpCode::UMad:
+    break;
+  }
+
+  // All three source operands must be present before indexing into Args.
+  if (Args.size() <= DXIL::OperandIndex::kTrinarySrc2OpIdx)
+    return nullptr;
+
+  const bool isFloat = (opcode == DXIL::OpCode::FMad);
+  Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
+  Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
+  Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
+
+  // Constants are uniqued, so pointer equality identifies the exact value.
+  // mad(0, x, y) and mad(x, 0, y) are just y.
+  Constant *zero = isFloat ? ConstantFP::get(op0->getType(), 0.0)
+                           : ConstantInt::get(op0->getType(), 0);
+  if (op0 == zero || op1 == zero)
+    return op2;
+
+  // mad(1, x, y) and mad(x, 1, y) are x + y.
+  Constant *one = isFloat ? ConstantFP::get(op0->getType(), 1.0)
+                          : ConstantInt::get(op0->getType(), 1);
+  if (op0 == one)
+    return CreateMadAsAdd(I, isFloat, op1, op2);
+  if (op1 == one)
+    return CreateMadAsAdd(I, isFloat, op0, op2);
+
+  return nullptr;
+}
+
+} // namespace hlsl

+ 12 - 0
lib/Analysis/InstructionSimplify.cpp

@@ -34,6 +34,9 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/ValueHandle.h"
 #include <algorithm>
+
+#include "llvm/Analysis/DxilSimplify.h" // HLSL Change - simplify dxil call.
+
 using namespace llvm;
 using namespace llvm::PatternMatch;
 
@@ -4072,6 +4075,15 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
     break;
   case Instruction::Call: {
     CallSite CS(cast<CallInst>(I));
+    // HLSL Change Begin - simplify dxil call.
+    if (hlsl::CanSimplify(CS.getCalledFunction())) {
+      SmallVector<Value *, 4> Args(CS.arg_begin(), CS.arg_end());
+      if (Value *DxilResult = hlsl::SimplifyDxilCall(CS.getCalledFunction(), Args, I)) {
+        Result = DxilResult;
+        break;
+      }
+    }
+    // HLSL Change End.
     Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), DL,
                           TLI, DT, AC, I);
     break;

+ 2 - 3
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -256,12 +256,11 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, hlsl::HLSLExten
   MPM.add(createDxilLegalizeStaticResourceUsePass());
   MPM.add(createDxilGenerationPass(NoOpt, ExtHelper));
   MPM.add(createDxilLoadMetadataPass()); // Ensure DxilModule is loaded for optimizations.
-
-  MPM.add(createSimplifyInstPass());
-
   // Propagate precise attribute.
   MPM.add(createDxilPrecisePropagatePass());
 
+  MPM.add(createSimplifyInstPass());
+
   // scalarize vector to scalar
   MPM.add(createScalarizerPass());
 

+ 0 - 1
tools/clang/test/CodeGenHLSL/Samples/DX11/FluidCS11_ForceCS_Grid.hlsl

@@ -6,7 +6,6 @@
 // CHECK: FMin
 // CHECK: IMax
 // CHECK: IMin
-// CHECK: IMad
 // CHECK: dot2
 // CHECK: Log
 // CHECK: Exp

+ 22 - 0
tools/clang/test/CodeGenHLSL/quick-test/mad_opt.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure no mad intrinsic call remains.
+// CHECK-NOT: dx.op.tertiary
+
+
+// Make sure a, c, e (each multiplied by 0) are not loaded.
+// CHECK-NOT: dx.op.loadInput.f32(i32 4, i32 0
+// CHECK-NOT: dx.op.loadInput.i32(i32 4, i32 2
+// CHECK-NOT: dx.op.loadInput.i32(i32 4, i32 4
+
+// Make sure the addends b, d, f are still loaded.
+// CHECK: dx.op.loadInput.i32(i32 4, i32 5
+// CHECK: dx.op.loadInput.i32(i32 4, i32 3
+// CHECK: dx.op.loadInput.f32(i32 4, i32 1
+
+// CHECK: fadd fast
+// CHECK: fadd fast
+
+float main(float a : A, float b :B, int c : C, int d :D, uint e :E, uint f :F) : SV_Target {
+  return mad(a, 0, b) + mad(0, c, d) + mad(e, 0, f);
+}

+ 16 - 0
tools/clang/test/CodeGenHLSL/quick-test/mad_opt2.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+
+// Make sure no mad intrinsic call remains.
+// CHECK-NOT: dx.op.tertiary
+// Make sure we have 2 int adds and 3 fast float adds.
+// CHECK: add i32
+// CHECK: add i32
+// CHECK: fadd fast
+// CHECK: fadd fast
+// CHECK: fadd fast
+
+
+float main(float a : A, float b :B, int c : C, int d :D, uint e :E, uint f :F) : SV_Target {
+  return mad(a, 1, b) + mad(1, c, d) + mad(e, 1, f);
+}

+ 10 - 0
tools/clang/test/CodeGenHLSL/quick-test/mad_opt3.hlsl

@@ -0,0 +1,10 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+
+// Make sure mad is not simplified when its result is marked precise.
+// CHECK: dx.op.tertiary.f32
+
+float main(float a : A, float b :B) : SV_Target {
+  precise float t = mad(a, 0, b);
+  return t;
+}