BlendTreeFloatMath1NodeTests.cpp 16 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
#include <cmath>
#include <vector>

#include <AzCore/Math/MathUtils.h>
#include <AzCore/Math/Random.h>
#include <EMotionFX/Source/AnimGraph.h>
#include <EMotionFX/Source/AnimGraphBindPoseNode.h>
#include <EMotionFX/Source/AnimGraphNode.h>
#include <EMotionFX/Source/AnimGraphStateMachine.h>
#include <EMotionFX/Source/BlendTree.h>
#include <EMotionFX/Source/BlendTreeBlend2Node.h>
#include <EMotionFX/Source/BlendTreeFloatConstantNode.h>
#include <EMotionFX/Source/BlendTreeFloatMath1Node.h>
#include <EMotionFX/Source/BlendTreeParameterNode.h>
#include <EMotionFX/Source/EMotionFXManager.h>
#include <EMotionFX/Source/Parameter/BoolParameter.h>
#include <EMotionFX/Source/Parameter/FloatSliderParameter.h>
#include <EMotionFX/Source/Parameter/IntSliderParameter.h>
#include <Tests/AnimGraphFixture.h>
  24. namespace EMotionFX
  25. {
  26. struct BlendTreeFloatMath1NodeTestData
  27. {
  28. std::vector<float> m_xInputFloat;
  29. std::vector<int> m_xInputInt;
  30. std::vector<bool> m_xInputBool;
  31. };
  32. std::vector<BlendTreeFloatMath1NodeTestData> blendTreeFloatMath1NodeTestData
  33. {
  34. {
  35. // TODO: MCore RandF function does not handle extreme values properly
  36. // eg. MCore::Math::RandF(0, FLT_MAX) returns inf
  37. {1000.3f, -1000.3f, 0.1f, -1.2f, 1.2f},
  38. {1000, -1000, 0, -1, 1},
  39. {true, false}
  40. }
  41. };
  42. class BlendTreeFloatMath1NodeFixture
  43. : public AnimGraphFixture
  44. , public ::testing::WithParamInterface<BlendTreeFloatMath1NodeTestData>
  45. {
  46. public:
  47. void ConstructGraph() override
  48. {
  49. AnimGraphFixture::ConstructGraph();
  50. m_param = GetParam();
  51. m_blendTreeAnimGraph = AnimGraphFactory::Create<OneBlendTreeNodeAnimGraph>();
  52. m_rootStateMachine = m_blendTreeAnimGraph->GetRootStateMachine();
  53. m_blendTree = m_blendTreeAnimGraph->GetBlendTreeNode();
  54. AddParameter<FloatSliderParameter>("FloatParam", 0.0f);
  55. AddParameter<BoolParameter>("BoolParam", false);
  56. AddParameter<IntSliderParameter>("IntParam", 0);
  57. /*
  58. +------------------+
  59. | |
  60. | bindPoseNode |
  61. | | +------------------+ +------------------+
  62. +------------------+-->+ | | |
  63. | blend2Node +-->+ finalNode |
  64. +------------------+ +------------------+ | | | |
  65. | | | +-->+------------------+ +------------------+
  66. | m_paramNode +-->+ m_floatMath1Node |
  67. | | | |
  68. +------------------+ +------------------+
  69. */
  70. BlendTreeFinalNode* finalNode = aznew BlendTreeFinalNode();
  71. m_blendTree->AddChildNode(finalNode);
  72. AnimGraphBindPoseNode* bindPoseNode = aznew AnimGraphBindPoseNode();
  73. m_blendTree->AddChildNode(bindPoseNode);
  74. BlendTreeBlend2Node* blend2Node = aznew BlendTreeBlend2Node();
  75. m_blendTree->AddChildNode(blend2Node);
  76. m_floatMath1Node = aznew BlendTreeFloatMath1Node();
  77. m_blendTree->AddChildNode(m_floatMath1Node);
  78. m_paramNode = aznew BlendTreeParameterNode();
  79. m_blendTree->AddChildNode(m_paramNode);
  80. // Connect the nodes.
  81. blend2Node->AddConnection(bindPoseNode, AnimGraphBindPoseNode::PORTID_OUTPUT_POSE, BlendTreeBlend2Node::INPUTPORT_POSE_A);
  82. blend2Node->AddConnection(bindPoseNode, AnimGraphBindPoseNode::PORTID_OUTPUT_POSE, BlendTreeBlend2Node::INPUTPORT_POSE_B);
  83. blend2Node->AddConnection(m_floatMath1Node, BlendTreeFloatMath1Node::OUTPUTPORT_RESULT, BlendTreeBlend2Node::INPUTPORT_WEIGHT);
  84. finalNode->AddConnection(blend2Node, BlendTreeBlend2Node::PORTID_OUTPUT_POSE, BlendTreeFinalNode::PORTID_INPUT_POSE);
  85. m_blendTreeAnimGraph->InitAfterLoading();
  86. }
  87. template <class paramType, class inputType>
  88. void TestInput(const AZStd::string& paramName, std::vector<inputType> xInputs)
  89. {
  90. BlendTreeConnection* connection = m_floatMath1Node->AddConnection(m_paramNode,
  91. static_cast<uint16>(m_paramNode->FindOutputPortByName(paramName)->m_portId), BlendTreeFloatMath1Node::PORTID_INPUT_X);
  92. for (inputType i : xInputs)
  93. {
  94. // Get and set parameter value to different test data inputs
  95. const AZ::Outcome<size_t> parameterIndex = m_animGraphInstance->FindParameterIndex(paramName);
  96. MCore::Attribute* param = m_animGraphInstance->GetParameterValue(static_cast<AZ::u32>(parameterIndex.GetValue()));
  97. paramType* typeParam = static_cast<paramType*>(param);
  98. typeParam->SetValue(i);
  99. for (AZ::u8 j = 0; j < BlendTreeFloatMath1Node::MATHFUNCTION_NUMFUNCTIONS; j++)
  100. {
  101. // Test input with all 26 math functions
  102. const BlendTreeFloatMath1Node::EMathFunction eMathFunc = static_cast<BlendTreeFloatMath1Node::EMathFunction>(j);
  103. m_floatMath1Node->SetMathFunction(eMathFunc);
  104. GetEMotionFX().Update(1.0f / 60.0f);
  105. const float actualOutput = m_floatMath1Node->GetOutputFloat(m_animGraphInstance,
  106. BlendTreeFloatMath1Node::OUTPUTPORT_RESULT)->GetValue();
  107. const float expectedOutput = CalculateMathFunctionOutput(eMathFunc, static_cast<float>(i));
  108. // Special cases for random float where float equal is not suitable
  109. // If actual and expected outputs are both NaN, then they should be considered same
  110. if (eMathFunc == BlendTreeFloatMath1Node::MATHFUNCTION_RANDOMFLOAT)
  111. {
  112. EXPECT_TRUE(RandomFloatIsInRange(actualOutput, 0, static_cast<float>(i))) << "Random float is not in range.";
  113. continue;
  114. }
  115. if (AZStd::isnan(actualOutput) && AZStd::isnan(expectedOutput))
  116. {
  117. continue;
  118. }
  119. if (AZStd::isinf(actualOutput) && AZStd::isinf(expectedOutput))
  120. {
  121. continue;
  122. }
  123. EXPECT_NEAR(actualOutput, expectedOutput, 0.004f) << "Actual and expected outputs does not match.";
  124. }
  125. }
  126. m_floatMath1Node->RemoveConnection(connection);
  127. }
  128. void SetUp() override
  129. {
  130. AnimGraphFixture::SetUp();
  131. m_animGraphInstance->Destroy();
  132. m_animGraphInstance = m_blendTreeAnimGraph->GetAnimGraphInstance(m_actorInstance, m_motionSet);
  133. }
  134. protected:
  135. BlendTree* m_blendTree = nullptr;
  136. BlendTreeFloatMath1Node* m_floatMath1Node = nullptr;
  137. BlendTreeFloatMath1NodeTestData m_param;
  138. BlendTreeParameterNode* m_paramNode = nullptr;
  139. private:
  140. bool RandomFloatIsInRange(float randomFloat, float bound1, float bound2)
  141. {
  142. if (bound1 > bound2)
  143. {
  144. return (randomFloat - bound2) <= (bound1 - bound2);
  145. }
  146. return (randomFloat - bound1) <= (bound2 - bound1);
  147. }
  148. template<class ParameterType, class ValueType>
  149. void AddParameter(const AZStd::string name, const ValueType& defaultValue)
  150. {
  151. ParameterType* parameter = aznew ParameterType();
  152. parameter->SetName(name);
  153. parameter->SetDefaultValue(defaultValue);
  154. m_blendTreeAnimGraph->AddParameter(parameter);
  155. }
  156. float CalculateMathFunctionOutput(BlendTreeFloatMath1Node::EMathFunction mathFunction, float input)
  157. {
  158. switch (mathFunction)
  159. {
  160. case BlendTreeFloatMath1Node::MATHFUNCTION_SIN:
  161. return CalculateSin(input);
  162. case BlendTreeFloatMath1Node::MATHFUNCTION_COS:
  163. return CalculateCos(input);
  164. case BlendTreeFloatMath1Node::MATHFUNCTION_TAN:
  165. return CalculateTan(input);
  166. case BlendTreeFloatMath1Node::MATHFUNCTION_SQR:
  167. return CalculateSqr(input);
  168. case BlendTreeFloatMath1Node::MATHFUNCTION_SQRT:
  169. return CalculateSqrt(input);
  170. case BlendTreeFloatMath1Node::MATHFUNCTION_ABS:
  171. return CalculateAbs(input);
  172. case BlendTreeFloatMath1Node::MATHFUNCTION_FLOOR:
  173. return CalculateFloor(input);
  174. case BlendTreeFloatMath1Node::MATHFUNCTION_CEIL:
  175. return CalculateCeil(input);
  176. case BlendTreeFloatMath1Node::MATHFUNCTION_ONEOVERINPUT:
  177. return CalculateOneOverInput(input);
  178. case BlendTreeFloatMath1Node::MATHFUNCTION_INVSQRT:
  179. return CalculateInvSqrt(input);
  180. case BlendTreeFloatMath1Node::MATHFUNCTION_LOG:
  181. return CalculateLog(input);
  182. case BlendTreeFloatMath1Node::MATHFUNCTION_LOG10:
  183. return CalculateLog10(input);
  184. case BlendTreeFloatMath1Node::MATHFUNCTION_EXP:
  185. return CalculateExp(input);
  186. case BlendTreeFloatMath1Node::MATHFUNCTION_FRACTION:
  187. return CalculateFraction(input);
  188. case BlendTreeFloatMath1Node::MATHFUNCTION_SIGN:
  189. return CalculateSign(input);
  190. case BlendTreeFloatMath1Node::MATHFUNCTION_ISPOSITIVE:
  191. return CalculateIsPositive(input);
  192. case BlendTreeFloatMath1Node::MATHFUNCTION_ISNEGATIVE:
  193. return CalculateIsNegative(input);
  194. case BlendTreeFloatMath1Node::MATHFUNCTION_ISNEARZERO:
  195. return CalculateIsNearZero(input);
  196. case BlendTreeFloatMath1Node::MATHFUNCTION_RANDOMFLOAT:
  197. return 0.0f;
  198. case BlendTreeFloatMath1Node::MATHFUNCTION_RADTODEG:
  199. return CalculateRadToDeg(input);
  200. case BlendTreeFloatMath1Node::MATHFUNCTION_DEGTORAD:
  201. return CalculateDegToRad(input);
  202. case BlendTreeFloatMath1Node::MATHFUNCTION_SMOOTHSTEP:
  203. return CalculateSmoothStep(input);
  204. case BlendTreeFloatMath1Node::MATHFUNCTION_ACOS:
  205. return CalculateACos(input);
  206. case BlendTreeFloatMath1Node::MATHFUNCTION_ASIN:
  207. return CalculateASin(input);
  208. case BlendTreeFloatMath1Node::MATHFUNCTION_ATAN:
  209. return CalculateATan(input);
  210. case BlendTreeFloatMath1Node::MATHFUNCTION_NEGATE:
  211. return CalculateNegate(input);
  212. default:
  213. AZ_Assert(false, "EMotionFX: Math function unknown.");
  214. return 0.0f;
  215. }
  216. }
  217. //-----------------------------------------------
  218. // The math functions
  219. //-----------------------------------------------
  220. float CalculateSin(float input) { return sin(input); }
  221. float CalculateCos(float input) { return cos(input); }
  222. float CalculateTan(float input) { return tan(input); }
  223. float CalculateSqr(float input) { return (input * input); }
  224. float CalculateSqrt(float input)
  225. {
  226. if (input > AZ::Constants::FloatEpsilon)
  227. {
  228. return sqrt(input);
  229. }
  230. return 0.0f;
  231. }
  232. float CalculateAbs(float input) { return abs(input); }
  233. float CalculateFloor(float input) { return floor(input); }
  234. float CalculateCeil(float input) { return ceil(input); }
  235. float CalculateOneOverInput(float input)
  236. {
  237. if (input > AZ::Constants::FloatEpsilon)
  238. {
  239. return 1.0f / input;
  240. }
  241. return 0.0f;
  242. }
  243. float CalculateInvSqrt(float input)
  244. {
  245. if (input > AZ::Constants::FloatEpsilon)
  246. {
  247. return 1.0f / sqrt(input);
  248. }
  249. return 0.0f;
  250. }
  251. float CalculateLog(float input)
  252. {
  253. if (input > AZ::Constants::FloatEpsilon)
  254. {
  255. return log(input);
  256. }
  257. return 0.0f;
  258. }
  259. float CalculateLog10(float input)
  260. {
  261. if (input > AZ::Constants::FloatEpsilon)
  262. {
  263. return log10f(input);
  264. }
  265. return 0.0f;
  266. }
  267. float CalculateExp(float input) { return exp(input); }
  268. float CalculateFraction(float input) { return AZ::GetMod(input, 1.0f); }
  269. float CalculateSign(float input)
  270. {
  271. if (input < 0.0f)
  272. {
  273. return -1.0f;
  274. }
  275. if (input > 0.0f)
  276. {
  277. return 1.0f;
  278. }
  279. return 0.0f;
  280. }
  281. float CalculateIsPositive(float input)
  282. {
  283. if (input >= 0.0f)
  284. {
  285. return 1.0f;
  286. }
  287. return 0.0f;
  288. }
  289. float CalculateIsNegative(float input)
  290. {
  291. if (input < 0.0f)
  292. {
  293. return 1.0f;
  294. }
  295. return 0.0f;
  296. }
  297. float CalculateIsNearZero(float input)
  298. {
  299. if ((input > -AZ::Constants::FloatEpsilon) && (input < AZ::Constants::FloatEpsilon))
  300. {
  301. return 1.0f;
  302. }
  303. return 0.0f;
  304. }
  305. float CalculateRadToDeg(float input) { return AZ::RadToDeg(input); }
  306. float CalculateDegToRad(float input) { return AZ::DegToRad(input); }
  307. float CalculateSmoothStep(float input)
  308. {
  309. const float f = AZ::GetClamp<float>(input, 0.0f, 1.0f);
  310. const float weight = (1.0f - cos(f * AZ::Constants::Pi)) * 0.5f;;
  311. return 0.0f * (1.0f - weight) + (weight * 1.0f);
  312. }
  313. float CalculateACos(float input) { return acos(input); }
  314. float CalculateASin(float input) { return asin(input); }
  315. float CalculateATan(float input) { return atan(input); }
  316. float CalculateNegate(float input) { return -input; }
  317. };
  318. TEST_P(BlendTreeFloatMath1NodeFixture, NoInput_OutputsCorrectFloatTest)
  319. {
  320. // Testing float math1 node without input node
  321. for (AZ::u8 i = 0; i < BlendTreeFloatMath1Node::MATHFUNCTION_NUMFUNCTIONS; i++)
  322. {
  323. BlendTreeFloatMath1Node::EMathFunction eMathFunc = static_cast<BlendTreeFloatMath1Node::EMathFunction>(i);
  324. m_floatMath1Node->SetMathFunction(eMathFunc);
  325. GetEMotionFX().Update(1.0f / 60.0f);
  326. // Default output should be 0.0f
  327. EXPECT_FLOAT_EQ(m_floatMath1Node->GetOutputFloat(m_animGraphInstance,
  328. BlendTreeFloatMath1Node::OUTPUTPORT_RESULT)->GetValue(), 0.0f) << "Expected Output: 0.0f";
  329. }
  330. };
  331. #if AZ_TRAIT_DISABLE_FAILED_EMOTION_FX_TESTS
  332. TEST_P(BlendTreeFloatMath1NodeFixture, DISABLED_FloatInput_OutputsCorrectFloatTest)
  333. #else
  334. TEST_P(BlendTreeFloatMath1NodeFixture, FloatInput_OutputsCorrectFloatTest)
  335. #endif // AZ_TRAIT_DISABLE_FAILED_EMOTION_FX_TESTS
  336. {
  337. TestInput<MCore::AttributeFloat, float>("FloatParam", m_param.m_xInputFloat);
  338. };
  339. #if AZ_TRAIT_DISABLE_FAILED_EMOTION_FX_TESTS
  340. TEST_P(BlendTreeFloatMath1NodeFixture, DISABLED_IntInput_OutputsCorrectFloatTest)
  341. #else
  342. TEST_P(BlendTreeFloatMath1NodeFixture, IntInput_OutputsCorrectFloatTest)
  343. #endif // AZ_TRAIT_DISABLE_FAILED_EMOTION_FX_TESTS
  344. {
  345. TestInput<MCore::AttributeInt32, int>("IntParam", m_param.m_xInputInt);
  346. };
  347. TEST_P(BlendTreeFloatMath1NodeFixture, BoolInput_OutputsCorrectFloatTest)
  348. {
  349. TestInput<MCore::AttributeBool, bool>("BoolParam", m_param.m_xInputBool);
  350. };
  351. INSTANTIATE_TEST_CASE_P(BlendTreeFloatMath1Node_ValidOutputTests,
  352. BlendTreeFloatMath1NodeFixture,
  353. ::testing::ValuesIn(blendTreeFloatMath1NodeTestData)
  354. );
  355. } // end namespace EMotionFX