|
@@ -0,0 +1,170 @@
|
|
|
+//===-- DxilSimplify.cpp - Fold dxil intrinsics into constants -----===//
|
|
|
+//
|
|
|
+// The LLVM Compiler Infrastructure
|
|
|
+//
|
|
|
+// This file is distributed under the University of Illinois Open Source
|
|
|
+// License. See LICENSE.TXT for details.
|
|
|
+//
|
|
|
+// Copyright (C) Microsoft Corporation. All rights reserved.
|
|
|
+//
|
|
|
+//===----------------------------------------------------------------------===//
|
|
|
+//
|
|
|
+//
|
|
|
+//===----------------------------------------------------------------------===//
|
|
|
+
|
|
|
+// simplify dxil op like mad 0, a, b->b.
|
|
|
+
|
|
|
+#include "llvm/Analysis/InstructionSimplify.h"
|
|
|
+#include "llvm/IR/Constants.h"
|
|
|
+#include "llvm/IR/Function.h"
|
|
|
+#include "llvm/IR/Instruction.h"
|
|
|
+#include "llvm/IR/Module.h"
|
|
|
+#include "llvm/IR/IRBuilder.h"
|
|
|
+
|
|
|
+#include "dxc/HLSL/DxilModule.h"
|
|
|
+#include "dxc/HLSL/DxilOperations.h"
|
|
|
+#include "llvm/Analysis/DxilConstantFolding.h"
|
|
|
+#include "llvm/Analysis/DxilSimplify.h"
|
|
|
+
|
|
|
+using namespace llvm;
|
|
|
+using namespace hlsl;
|
|
|
+
|
|
|
+namespace {
|
|
|
+DXIL::OpCode GetOpcode(Value *opArg) {
|
|
|
+ if (ConstantInt *ci = dyn_cast<ConstantInt>(opArg)) {
|
|
|
+ uint64_t opcode = ci->getLimitedValue();
|
|
|
+ if (opcode < static_cast<uint64_t>(OP::OpCode::NumOpCodes)) {
|
|
|
+ return static_cast<OP::OpCode>(opcode);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return DXIL::OpCode::NumOpCodes;
|
|
|
+}
|
|
|
+} // namespace
|
|
|
+
|
|
|
+namespace hlsl {
|
|
|
+bool CanSimplify(const llvm::Function *F) {
|
|
|
+ // Only simplify dxil functions when we have a valid dxil module.
|
|
|
+ if (!F->getParent()->HasDxilModule()) {
|
|
|
+ assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Lookup opcode class in dxil module. Set default value to invalid class.
|
|
|
+ OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
|
|
|
+ const bool found =
|
|
|
+ F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass);
|
|
|
+
|
|
|
+ // Return true for those dxil operation classes we can simplify.
|
|
|
+ if (found) {
|
|
|
+ switch (opClass) {
|
|
|
+ default:
|
|
|
+ break;
|
|
|
+ case OP::OpCodeClass::Tertiary:
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return false;
|
|
|
+}
|
|
|
+
|
|
|
+/// \brief Given a function and set of arguments, see if we can fold the
|
|
|
+/// result as dxil operation.
|
|
|
+///
|
|
|
+/// If this call could not be simplified returns null.
|
|
|
+Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
|
|
|
+ llvm::Instruction *I) {
|
|
|
+ if (!F->getParent()->HasDxilModule()) {
|
|
|
+ assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ DxilModule &DM = F->getParent()->GetDxilModule();
|
|
|
+ // Skip precise.
|
|
|
+ if (DM.IsPrecise(I))
|
|
|
+ return nullptr;
|
|
|
+
|
|
|
+ // Lookup opcode class in dxil module. Set default value to invalid class.
|
|
|
+ OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
|
|
|
+ const bool found = DM.GetOP()->GetOpCodeClass(F, opClass);
|
|
|
+ if (!found)
|
|
|
+ return nullptr;
|
|
|
+
|
|
|
+ DXIL::OpCode opcode = GetOpcode(Args[0]);
|
|
|
+ if (opcode == DXIL::OpCode::NumOpCodes)
|
|
|
+ return nullptr;
|
|
|
+
|
|
|
+ if (CanConstantFoldCallTo(F)) {
|
|
|
+ bool bAllConstant = true;
|
|
|
+ SmallVector<Constant *, 4> ConstantArgs;
|
|
|
+ ConstantArgs.reserve(Args.size());
|
|
|
+ for (Value *V : Args) {
|
|
|
+ Constant *C = dyn_cast<Constant>(V);
|
|
|
+ if (!C) {
|
|
|
+ bAllConstant = false;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ ConstantArgs.push_back(C);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (bAllConstant)
|
|
|
+ return hlsl::ConstantFoldScalarCall(F->getName(), F->getReturnType(),
|
|
|
+ ConstantArgs);
|
|
|
+ }
|
|
|
+
|
|
|
+ switch (opcode) {
|
|
|
+ default:
|
|
|
+ return nullptr;
|
|
|
+ case DXIL::OpCode::FMad: {
|
|
|
+ Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
|
|
|
+ Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
|
|
|
+ Constant *zero = ConstantFP::get(op0->getType(), 0);
|
|
|
+ if (op0 == zero)
|
|
|
+ return op2;
|
|
|
+ Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
|
|
|
+ if (op1 == zero)
|
|
|
+ return op2;
|
|
|
+
|
|
|
+ Constant *one = ConstantFP::get(op0->getType(), 1);
|
|
|
+ if (op0 == one) {
|
|
|
+ IRBuilder<> Builder(I);
|
|
|
+ llvm::FastMathFlags FMF;
|
|
|
+ FMF.setUnsafeAlgebraHLSL();
|
|
|
+ Builder.SetFastMathFlags(FMF);
|
|
|
+ return Builder.CreateFAdd(op1, op2);
|
|
|
+ }
|
|
|
+ if (op1 == one) {
|
|
|
+ IRBuilder<> Builder(I);
|
|
|
+ llvm::FastMathFlags FMF;
|
|
|
+ FMF.setUnsafeAlgebraHLSL();
|
|
|
+ Builder.SetFastMathFlags(FMF);
|
|
|
+
|
|
|
+ return Builder.CreateFAdd(op0, op2);
|
|
|
+ }
|
|
|
+ return nullptr;
|
|
|
+ } break;
|
|
|
+ case DXIL::OpCode::IMad:
|
|
|
+ case DXIL::OpCode::UMad: {
|
|
|
+ Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
|
|
|
+ Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
|
|
|
+ Constant *zero = ConstantInt::get(op0->getType(), 0);
|
|
|
+ if (op0 == zero)
|
|
|
+ return op2;
|
|
|
+ Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
|
|
|
+ if (op1 == zero)
|
|
|
+ return op2;
|
|
|
+
|
|
|
+ Constant *one = ConstantInt::get(op0->getType(), 1);
|
|
|
+ if (op0 == one) {
|
|
|
+ IRBuilder<> Builder(I);
|
|
|
+ return Builder.CreateAdd(op1, op2);
|
|
|
+ }
|
|
|
+ if (op1 == one) {
|
|
|
+ IRBuilder<> Builder(I);
|
|
|
+ return Builder.CreateAdd(op0, op2);
|
|
|
+ }
|
|
|
+ return nullptr;
|
|
|
+ } break;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+} // namespace hlsl
|