123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- //===-- DxilSimplify.cpp - Fold dxil intrinsics into constants -----===//
- //
- // The LLVM Compiler Infrastructure
- //
- // This file is distributed under the University of Illinois Open Source
- // License. See LICENSE.TXT for details.
- //
- // Copyright (C) Microsoft Corporation. All rights reserved.
- //
- //===----------------------------------------------------------------------===//
- //
- //
- //===----------------------------------------------------------------------===//
- // Simplify DXIL operations, e.g. mad(0, a, b) -> b.
- #include "llvm/Analysis/InstructionSimplify.h"
- #include "llvm/IR/Constants.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/Instruction.h"
- #include "llvm/IR/Module.h"
- #include "llvm/IR/IRBuilder.h"
- #include "dxc/DXIL/DxilModule.h"
- #include "dxc/DXIL/DxilOperations.h"
- #include "llvm/Analysis/DxilConstantFolding.h"
- #include "llvm/Analysis/DxilSimplify.h"
- using namespace llvm;
- using namespace hlsl;
- namespace {
- DXIL::OpCode GetOpcode(Value *opArg) {
- if (ConstantInt *ci = dyn_cast<ConstantInt>(opArg)) {
- uint64_t opcode = ci->getLimitedValue();
- if (opcode < static_cast<uint64_t>(OP::OpCode::NumOpCodes)) {
- return static_cast<OP::OpCode>(opcode);
- }
- }
- return DXIL::OpCode::NumOpCodes;
- }
- } // namespace
- namespace hlsl {
- bool CanSimplify(const llvm::Function *F) {
- // Only simplify dxil functions when we have a valid dxil module.
- if (!F->getParent()->HasDxilModule()) {
- assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
- return false;
- }
- if (CanConstantFoldCallTo(F))
- return true;
- // Lookup opcode class in dxil module. Set default value to invalid class.
- OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
- const bool found =
- F->getParent()->GetDxilModule().GetOP()->GetOpCodeClass(F, opClass);
- // Return true for those dxil operation classes we can simplify.
- if (found) {
- switch (opClass) {
- default:
- break;
- case OP::OpCodeClass::Tertiary:
- return true;
- }
- }
- return false;
- }
- /// \brief Given a function and set of arguments, see if we can fold the
- /// result as dxil operation.
- ///
- /// If this call could not be simplified returns null.
- Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
- llvm::Instruction *I,
- bool MayInsert)
- {
- if (!F->getParent()->HasDxilModule()) {
- assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
- return nullptr;
- }
- DxilModule &DM = F->getParent()->GetDxilModule();
- // Skip precise.
- if (DM.IsPrecise(I))
- return nullptr;
- // Lookup opcode class in dxil module. Set default value to invalid class.
- OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
- const bool found = DM.GetOP()->GetOpCodeClass(F, opClass);
- if (!found)
- return nullptr;
- DXIL::OpCode opcode = GetOpcode(Args[0]);
- if (opcode == DXIL::OpCode::NumOpCodes)
- return nullptr;
- if (CanConstantFoldCallTo(F)) {
- bool bAllConstant = true;
- SmallVector<Constant *, 4> ConstantArgs;
- ConstantArgs.reserve(Args.size());
- for (Value *V : Args) {
- Constant *C = dyn_cast<Constant>(V);
- if (!C) {
- bAllConstant = false;
- break;
- }
- ConstantArgs.push_back(C);
- }
- if (bAllConstant)
- return hlsl::ConstantFoldScalarCall(F->getName(), F->getReturnType(),
- ConstantArgs);
- }
- switch (opcode) {
- default:
- return nullptr;
- case DXIL::OpCode::FMad: {
- Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
- Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
- Constant *zero = ConstantFP::get(op0->getType(), 0);
- if (op0 == zero)
- return op2;
- Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
- if (op1 == zero)
- return op2;
- if (MayInsert) {
- Constant *one = ConstantFP::get(op0->getType(), 1);
- if (op0 == one) {
- IRBuilder<> Builder(I);
- llvm::FastMathFlags FMF;
- FMF.setUnsafeAlgebraHLSL();
- Builder.SetFastMathFlags(FMF);
- return Builder.CreateFAdd(op1, op2);
- }
- if (op1 == one) {
- IRBuilder<> Builder(I);
- llvm::FastMathFlags FMF;
- FMF.setUnsafeAlgebraHLSL();
- Builder.SetFastMathFlags(FMF);
- return Builder.CreateFAdd(op0, op2);
- }
- }
- return nullptr;
- } break;
- case DXIL::OpCode::IMad:
- case DXIL::OpCode::UMad: {
- Value *op0 = Args[DXIL::OperandIndex::kTrinarySrc0OpIdx];
- Value *op2 = Args[DXIL::OperandIndex::kTrinarySrc2OpIdx];
- Constant *zero = ConstantInt::get(op0->getType(), 0);
- if (op0 == zero)
- return op2;
- Value *op1 = Args[DXIL::OperandIndex::kTrinarySrc1OpIdx];
- if (op1 == zero)
- return op2;
- if (MayInsert) {
- Constant *one = ConstantInt::get(op0->getType(), 1);
- if (op0 == one) {
- IRBuilder<> Builder(I);
- return Builder.CreateAdd(op1, op2);
- }
- if (op1 == one) {
- IRBuilder<> Builder(I);
- return Builder.CreateAdd(op0, op2);
- }
- }
- return nullptr;
- } break;
- }
- }
- } // namespace hlsl
|