|
@@ -421,19 +421,17 @@ static Constant *ConstantFoldQuaternaryIntInstrinsic(OP::OpCode opcode, Type *Ty
|
|
|
return ConstantInt::get(Ty, result);
|
|
return ConstantInt::get(Ty, result);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Return true if opcode is for a dot operation.
|
|
|
|
|
-static bool IsDotOpcode(OP::OpCode opcode) {
|
|
|
|
|
- return opcode == OP::OpCode::Dot2
|
|
|
|
|
- || opcode == OP::OpCode::Dot3
|
|
|
|
|
- || opcode == OP::OpCode::Dot4;
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
// Top level function to constant fold floating point intrinsics.
|
|
// Top level function to constant fold floating point intrinsics.
|
|
|
static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
|
|
static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const DxilIntrinsicOperands &IntrinsicOperands) {
|
|
|
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
|
|
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
|
|
|
return nullptr;
|
|
return nullptr;
|
|
|
|
|
|
|
|
- if (IntrinsicOperands.Size() == 1) {
|
|
|
|
|
|
|
+ OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
|
|
|
|
|
+
|
|
|
|
|
+ switch (opClass) {
|
|
|
|
|
+ default: break;
|
|
|
|
|
+ case OP::OpCodeClass::Unary: {
|
|
|
|
|
+ assert(IntrinsicOperands.Size() == 1);
|
|
|
ConstantFP *Op = IntrinsicOperands.GetConstantFloat(0);
|
|
ConstantFP *Op = IntrinsicOperands.GetConstantFloat(0);
|
|
|
|
|
|
|
|
if (!IsValidOp(Op))
|
|
if (!IsValidOp(Op))
|
|
@@ -441,7 +439,8 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
|
|
|
|
|
|
|
|
return ConstantFoldUnaryFPIntrinsic(opcode, Ty, Op);
|
|
return ConstantFoldUnaryFPIntrinsic(opcode, Ty, Op);
|
|
|
}
|
|
}
|
|
|
- else if (IntrinsicOperands.Size() == 2) {
|
|
|
|
|
|
|
+ case OP::OpCodeClass::Binary: {
|
|
|
|
|
+ assert(IntrinsicOperands.Size() == 2);
|
|
|
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
|
|
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
|
|
|
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
|
|
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
|
|
|
|
|
|
|
@@ -450,7 +449,8 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
|
|
|
|
|
|
|
|
return ConstantFoldBinaryFPIntrinsic(opcode, Ty, Op1, Op2);
|
|
return ConstantFoldBinaryFPIntrinsic(opcode, Ty, Op1, Op2);
|
|
|
}
|
|
}
|
|
|
- else if (IntrinsicOperands.Size() == 3) {
|
|
|
|
|
|
|
+ case OP::OpCodeClass::Tertiary: {
|
|
|
|
|
+ assert(IntrinsicOperands.Size() == 3);
|
|
|
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
|
|
ConstantFP *Op1 = IntrinsicOperands.GetConstantFloat(0);
|
|
|
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
|
|
ConstantFP *Op2 = IntrinsicOperands.GetConstantFloat(1);
|
|
|
ConstantFP *Op3 = IntrinsicOperands.GetConstantFloat(2);
|
|
ConstantFP *Op3 = IntrinsicOperands.GetConstantFloat(2);
|
|
@@ -460,7 +460,9 @@ static Constant *ConstantFoldFPIntrinsic(OP::OpCode opcode, Type *Ty, const Dxil
|
|
|
|
|
|
|
|
return ConstantFoldTernaryFPIntrinsic(opcode, Ty, Op1, Op2, Op3);
|
|
return ConstantFoldTernaryFPIntrinsic(opcode, Ty, Op1, Op2, Op3);
|
|
|
}
|
|
}
|
|
|
- else if (IsDotOpcode(opcode)) {
|
|
|
|
|
|
|
+ case OP::OpCodeClass::Dot2:
|
|
|
|
|
+ case OP::OpCodeClass::Dot3:
|
|
|
|
|
+ case OP::OpCodeClass::Dot4:
|
|
|
return ConstantFoldDot(opcode, Ty, IntrinsicOperands);
|
|
return ConstantFoldDot(opcode, Ty, IntrinsicOperands);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -472,14 +474,21 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|
|
if (Ty->getScalarSizeInBits() > (sizeof(int64_t) * CHAR_BIT))
|
|
if (Ty->getScalarSizeInBits() > (sizeof(int64_t) * CHAR_BIT))
|
|
|
return nullptr;
|
|
return nullptr;
|
|
|
|
|
|
|
|
- if (IntrinsicOperands.Size() == 1) {
|
|
|
|
|
|
|
+ OP::OpCodeClass opClass = OP::GetOpCodeClass(opcode);
|
|
|
|
|
+
|
|
|
|
|
+ switch (opClass) {
|
|
|
|
|
+ default: break;
|
|
|
|
|
+ case OP::OpCodeClass::Unary:
|
|
|
|
|
+ case OP::OpCodeClass::UnaryBits: {
|
|
|
|
|
+ assert(IntrinsicOperands.Size() == 1);
|
|
|
ConstantInt *Op = IntrinsicOperands.GetConstantInt(0);
|
|
ConstantInt *Op = IntrinsicOperands.GetConstantInt(0);
|
|
|
if (!Op)
|
|
if (!Op)
|
|
|
return nullptr;
|
|
return nullptr;
|
|
|
|
|
|
|
|
return ConstantFoldUnaryIntIntrinsic(opcode, Ty, Op);
|
|
return ConstantFoldUnaryIntIntrinsic(opcode, Ty, Op);
|
|
|
}
|
|
}
|
|
|
- else if (IntrinsicOperands.Size() == 2) {
|
|
|
|
|
|
|
+ case OP::OpCodeClass::Binary: {
|
|
|
|
|
+ assert(IntrinsicOperands.Size() == 2);
|
|
|
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
|
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
|
|
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
|
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
|
|
if (!Op1 || !Op2)
|
|
if (!Op1 || !Op2)
|
|
@@ -487,7 +496,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|
|
|
|
|
|
|
return ConstantFoldBinaryIntIntrinsic(opcode, Ty, Op1, Op2);
|
|
return ConstantFoldBinaryIntIntrinsic(opcode, Ty, Op1, Op2);
|
|
|
}
|
|
}
|
|
|
- else if (IntrinsicOperands.Size() == 3) {
|
|
|
|
|
|
|
+ case OP::OpCodeClass::Tertiary: {
|
|
|
|
|
+ assert(IntrinsicOperands.Size() == 3);
|
|
|
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
|
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
|
|
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
|
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
|
|
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
|
|
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
|
|
@@ -496,7 +506,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|
|
|
|
|
|
|
return ConstantFoldTernaryIntIntrinsic(opcode, Ty, Op1, Op2, Op3);
|
|
return ConstantFoldTernaryIntIntrinsic(opcode, Ty, Op1, Op2, Op3);
|
|
|
}
|
|
}
|
|
|
- else if (IntrinsicOperands.Size() == 4) {
|
|
|
|
|
|
|
+ case OP::OpCodeClass::Quaternary: {
|
|
|
|
|
+ assert(IntrinsicOperands.Size() == 4);
|
|
|
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
|
ConstantInt *Op1 = IntrinsicOperands.GetConstantInt(0);
|
|
|
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
|
ConstantInt *Op2 = IntrinsicOperands.GetConstantInt(1);
|
|
|
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
|
|
ConstantInt *Op3 = IntrinsicOperands.GetConstantInt(2);
|
|
@@ -506,6 +517,8 @@ static Constant *ConstantFoldIntIntrinsic(OP::OpCode opcode, Type *Ty, const Dxi
|
|
|
|
|
|
|
|
return ConstantFoldQuaternaryIntInstrinsic(opcode, Ty, Op1, Op2, Op3, Op4);
|
|
return ConstantFoldQuaternaryIntInstrinsic(opcode, Ty, Op1, Op2, Op3, Op4);
|
|
|
}
|
|
}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
return nullptr;
|
|
return nullptr;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -535,6 +548,8 @@ bool hlsl::CanConstantFoldCallTo(const Function *F) {
|
|
|
return false;
|
|
return false;
|
|
|
|
|
|
|
|
// Check match using startswith to get all overloads.
|
|
// Check match using startswith to get all overloads.
|
|
|
|
|
+ // We cannot use the opcode class here because constant folding
|
|
|
|
|
+ // may run without a DxilModule available.
|
|
|
StringRef Name = F->getName();
|
|
StringRef Name = F->getName();
|
|
|
if (Name.startswith("dx.op.unary"))
|
|
if (Name.startswith("dx.op.unary"))
|
|
|
return true;
|
|
return true;
|