JointsTrajectoryComponent.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "JointsTrajectoryComponent.h"
  9. #include <AzCore/Outcome/Outcome.h>
  10. #include <AzCore/Serialization/EditContext.h>
  11. #include <PhysX/ArticulationJointBus.h>
  12. #include <ROS2/Clock/ROS2ClockRequestBus.h>
  13. #include <ROS2/Frame/ROS2FrameComponent.h>
  14. #include <ROS2/ROS2Bus.h>
  15. #include <ROS2/ROS2NamesBus.h>
  16. #include <ROS2/Utilities/ROS2Conversions.h>
  17. #include <ROS2Controllers/Manipulation/JointsManipulationRequests.h>
  18. namespace ROS2Controllers
  19. {
  20. JointsTrajectoryComponent::JointsTrajectoryComponent(const AZStd::string& followTrajectoryActionName)
  21. : m_followTrajectoryActionName(followTrajectoryActionName)
  22. {
  23. }
  24. void JointsTrajectoryComponent::Activate()
  25. {
  26. auto* ros2Frame = GetEntity()->FindComponent<ROS2::ROS2FrameComponent>();
  27. AZ_Assert(ros2Frame, "Missing Frame Component!");
  28. AZStd::string namespacedAction;
  29. ROS2::ROS2NamesRequestBus::BroadcastResult(
  30. namespacedAction,
  31. &ROS2::ROS2NamesRequestBus::Events::GetNamespacedName,
  32. ros2Frame->GetNamespace(),
  33. m_followTrajectoryActionName);
  34. m_followTrajectoryServer = AZStd::make_unique<FollowJointTrajectoryActionServer>(namespacedAction, GetEntityId());
  35. AZ::TickBus::Handler::BusConnect();
  36. JointsTrajectoryRequestBus::Handler::BusConnect(GetEntityId());
  37. ROS2::ROS2ClockRequestBus::BroadcastResult(m_lastTickTimestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
  38. }
  39. ManipulationJoints& JointsTrajectoryComponent::GetManipulationJoints()
  40. {
  41. if (m_manipulationJoints.empty())
  42. {
  43. JointsManipulationRequestBus::EventResult(m_manipulationJoints, GetEntityId(), &JointsManipulationRequests::GetJoints);
  44. }
  45. return m_manipulationJoints;
  46. }
  47. void JointsTrajectoryComponent::Deactivate()
  48. {
  49. JointsTrajectoryRequestBus::Handler::BusDisconnect();
  50. AZ::TickBus::Handler::BusDisconnect();
  51. m_followTrajectoryServer.reset();
  52. }
  53. void JointsTrajectoryComponent::Reflect(AZ::ReflectContext* context)
  54. {
  55. if (AZ::SerializeContext* serialize = azrtti_cast<AZ::SerializeContext*>(context))
  56. {
  57. serialize->Class<JointsTrajectoryComponent, AZ::Component>()->Version(1)
  58. ->Field("Action name", &JointsTrajectoryComponent::m_followTrajectoryActionName)
  59. ->Field("Check for position errors", &JointsTrajectoryComponent::m_checkForPositionErrors)
  60. ->Field("Joint goal tolerance", &JointsTrajectoryComponent::m_jointPositionTolerance)
  61. ->Field("Check for velocity", &JointsTrajectoryComponent::m_checkForVelocity)
  62. ->Field("Joint velocity tolerance", &JointsTrajectoryComponent::m_jointVelocityTolerance);
  63. if (AZ::EditContext* ec = serialize->GetEditContext())
  64. {
  65. ec->Class<JointsTrajectoryComponent>("JointsTrajectoryComponent", "Component to control a robotic arm using trajectories")
  66. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  67. ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC("Game"))
  68. ->Attribute(AZ::Edit::Attributes::Category, "ROS2")
  69. ->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/JointsTrajectoryComponent.svg")
  70. ->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/JointsTrajectoryComponent.svg")
  71. ->DataElement(
  72. AZ::Edit::UIHandlers::Default,
  73. &JointsTrajectoryComponent::m_followTrajectoryActionName,
  74. "Action Name",
  75. "Name the follow trajectory action server to accept movement commands")
  76. ->DataElement(
  77. AZ::Edit::UIHandlers::Default,
  78. &JointsTrajectoryComponent::m_checkForPositionErrors,
  79. "Check for Position Errors",
  80. "If true, check if joints reached the goal position before reporting success")
  81. ->Attribute(AZ::Edit::Attributes::ChangeNotify, AZ::Edit::PropertyRefreshLevels::AttributesAndValues)
  82. ->DataElement(
  83. AZ::Edit::UIHandlers::Default,
  84. &JointsTrajectoryComponent::m_jointPositionTolerance,
  85. "Joint Position Tolerance",
  86. "The threshold for joint position errors to report the goal as reached (one value for all joints, units depend on joint type)")
  87. ->Attribute(AZ::Edit::Attributes::Visibility, &JointsTrajectoryComponent::ShouldCheckForPositionErrors)
  88. ->DataElement(
  89. AZ::Edit::UIHandlers::Default,
  90. &JointsTrajectoryComponent::m_checkForVelocity,
  91. "Check for Velocity",
  92. "If true, check if joints velocity is below threshold before reporting success")
  93. ->Attribute(AZ::Edit::Attributes::ChangeNotify, AZ::Edit::PropertyRefreshLevels::AttributesAndValues)
  94. ->DataElement(
  95. AZ::Edit::UIHandlers::Default,
  96. &JointsTrajectoryComponent::m_jointVelocityTolerance,
  97. "Joint Velocity Tolerance",
  98. "The threshold for joint velocities under which to report the goal as reached (one value for all joints, units depend on joint type)"
  99. )
  100. ->Attribute(AZ::Edit::Attributes::Visibility, &JointsTrajectoryComponent::ShouldCheckForVelocity);
  101. }
  102. }
  103. }
  104. void JointsTrajectoryComponent::GetRequiredServices(AZ::ComponentDescriptor::DependencyArrayType& required)
  105. {
  106. required.push_back(AZ_CRC_CE("ROS2Frame"));
  107. required.push_back(AZ_CRC_CE("JointsManipulationService"));
  108. }
  109. void JointsTrajectoryComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  110. {
  111. provided.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
  112. }
  113. void JointsTrajectoryComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
  114. {
  115. incompatible.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
  116. }
  117. bool JointsTrajectoryComponent::ShouldCheckForPositionErrors()
  118. {
  119. return m_checkForPositionErrors;
  120. }
  121. bool JointsTrajectoryComponent::ShouldCheckForVelocity()
  122. {
  123. return m_checkForVelocity;
  124. }
  125. AZ::Outcome<void, JointsTrajectoryComponent::TrajectoryResult> JointsTrajectoryComponent::StartTrajectoryGoal(
  126. TrajectoryGoalPtr trajectoryGoal)
  127. {
  128. if (m_goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
  129. {
  130. auto result = JointsTrajectoryComponent::TrajectoryResult();
  131. result.error_code = JointsTrajectoryComponent::TrajectoryResult::INVALID_GOAL;
  132. result.error_string = "Another trajectory goal is executing. Wait for completion or cancel it";
  133. return AZ::Failure(result);
  134. }
  135. auto validationResult = ValidateGoal(trajectoryGoal);
  136. if (!validationResult)
  137. {
  138. return validationResult;
  139. }
  140. m_trajectoryGoal = *trajectoryGoal;
  141. ROS2::ROS2ClockRequestBus::BroadcastResult(m_trajectoryExecutionStartTime, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
  142. m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Executing;
  143. return AZ::Success();
  144. }
  145. AZ::Outcome<void, JointsTrajectoryComponent::TrajectoryResult> JointsTrajectoryComponent::ValidateGoal(TrajectoryGoalPtr trajectoryGoal)
  146. {
  147. // Check joint names validity
  148. for (const auto& jointName : trajectoryGoal->trajectory.joint_names)
  149. {
  150. AZStd::string azJointName(jointName.c_str());
  151. if (m_manipulationJoints.find(azJointName) == m_manipulationJoints.end())
  152. {
  153. AZ_Printf("JointsTrajectoryComponent", "Trajectory goal is invalid: no joint %s in manipulator", azJointName.c_str());
  154. auto result = JointsTrajectoryComponent::TrajectoryResult();
  155. result.error_code = JointsTrajectoryComponent::TrajectoryResult::INVALID_JOINTS;
  156. result.error_string = std::string(
  157. AZStd::string::format("Trajectory goal is invalid: no joint %s in manipulator", azJointName.c_str()).c_str());
  158. return AZ::Failure(result);
  159. }
  160. }
  161. return AZ::Success();
  162. }
  163. void JointsTrajectoryComponent::UpdateFeedback()
  164. {
  165. auto goalStatus = GetGoalStatus();
  166. if (goalStatus != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
  167. {
  168. return;
  169. }
  170. auto feedback = std::make_shared<control_msgs::action::FollowJointTrajectory::Feedback>();
  171. trajectory_msgs::msg::JointTrajectoryPoint desiredPoint = m_trajectoryGoal.trajectory.points.front();
  172. trajectory_msgs::msg::JointTrajectoryPoint actualPoint;
  173. size_t jointCount = m_trajectoryGoal.trajectory.joint_names.size();
  174. for (size_t jointIndex = 0; jointIndex < jointCount; jointIndex++)
  175. {
  176. AZStd::string jointName(m_trajectoryGoal.trajectory.joint_names[jointIndex].c_str());
  177. std::string jointNameStdString(jointName.c_str());
  178. feedback->joint_names.push_back(jointNameStdString);
  179. float currentJointPosition;
  180. float currentJointVelocity;
  181. auto& jointInfo = m_manipulationJoints[jointName];
  182. PhysX::ArticulationJointRequestBus::Event(
  183. jointInfo.m_entityComponentIdPair.GetEntityId(),
  184. [&](PhysX::ArticulationJointRequests* articulationJointRequests)
  185. {
  186. currentJointPosition = articulationJointRequests->GetJointPosition(jointInfo.m_axis);
  187. currentJointVelocity = articulationJointRequests->GetJointVelocity(jointInfo.m_axis);
  188. });
  189. actualPoint.positions.push_back(static_cast<double>(currentJointPosition));
  190. actualPoint.velocities.push_back(static_cast<double>(currentJointVelocity));
  191. // Acceleration should also be filled in somehow, or removed from the trajectory altogether.
  192. }
  193. trajectory_msgs::msg::JointTrajectoryPoint currentError;
  194. for (size_t jointIndex = 0; jointIndex < jointCount; jointIndex++)
  195. {
  196. currentError.positions.push_back(actualPoint.positions[jointIndex] - desiredPoint.positions[jointIndex]);
  197. currentError.velocities.push_back(actualPoint.velocities[jointIndex] - desiredPoint.velocities[jointIndex]);
  198. }
  199. feedback->desired = desiredPoint;
  200. feedback->actual = actualPoint;
  201. feedback->error = currentError;
  202. m_followTrajectoryServer->PublishFeedback(feedback);
  203. }
  204. AZ::Outcome<void, AZStd::string> JointsTrajectoryComponent::CancelTrajectoryGoal()
  205. {
  206. m_trajectoryGoal.trajectory.points.clear();
  207. m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled;
  208. return AZ::Success();
  209. }
  210. JointsTrajectoryRequests::TrajectoryActionStatus JointsTrajectoryComponent::GetGoalStatus()
  211. {
  212. return m_goalStatus;
  213. }
  214. void JointsTrajectoryComponent::FollowTrajectory(const uint64_t deltaTimeNs)
  215. {
  216. auto goalStatus = GetGoalStatus();
  217. if (goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled)
  218. {
  219. JointsManipulationRequestBus::Event(GetEntityId(), &JointsManipulationRequests::Stop);
  220. auto result = std::make_shared<FollowJointTrajectoryActionServer::FollowJointTrajectory::Result>();
  221. result->error_string = "User Cancelled";
  222. result->error_code = FollowJointTrajectoryActionServer::FollowJointTrajectory::Result::SUCCESSFUL;
  223. m_followTrajectoryServer->CancelGoal(result);
  224. return;
  225. }
  226. if (goalStatus != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
  227. {
  228. return;
  229. }
  230. if (m_trajectoryGoal.trajectory.points.size() == 0)
  231. { // The manipulator has reached the goal.
  232. AZ_TracePrintf("JointsManipulationComponent", "Goal Concluded: all points reached\n");
  233. auto successResult = std::make_shared<control_msgs::action::FollowJointTrajectory::Result>(); //!< Empty defaults to success.
  234. m_followTrajectoryServer->GoalSuccess(successResult);
  235. m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Succeeded;
  236. return;
  237. }
  238. auto desiredGoal = m_trajectoryGoal.trajectory.points.front();
  239. rclcpp::Duration targetGoalTime = rclcpp::Duration(desiredGoal.time_from_start); //!< Requested arrival time for trajectory point.
  240. builtin_interfaces::msg::Time timestamp;
  241. ROS2::ROS2ClockRequestBus::BroadcastResult(timestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
  242. const rclcpp::Time timeNow = rclcpp::Time(timestamp); //!< Current simulation time.
  243. rclcpp::Duration threshold = rclcpp::Duration::from_nanoseconds(1e7);
  244. // Jump to the next point if current simulation time is ahead of timeFromStart
  245. bool canJumpToNextPoint = m_trajectoryExecutionStartTime + targetGoalTime <= timeNow + threshold;
  246. // But if it's the last point, wait until the manipulator reaches it
  247. bool lastTrajectoryPoint = m_trajectoryGoal.trajectory.points.size() == 1;
  248. if (lastTrajectoryPoint)
  249. {
  250. if (m_checkForPositionErrors)
  251. {
  252. canJumpToNextPoint &= CheckIfPositionReachedTolerance(desiredGoal);
  253. }
  254. if (m_checkForVelocity)
  255. {
  256. canJumpToNextPoint &= CheckIfVelocityReachedTolerance();
  257. }
  258. }
  259. if (canJumpToNextPoint)
  260. { // Jump to the next point in the trajectory
  261. m_trajectoryGoal.trajectory.points.erase(m_trajectoryGoal.trajectory.points.begin());
  262. FollowTrajectory(deltaTimeNs);
  263. return;
  264. }
  265. MoveToNextPoint(desiredGoal);
  266. }
  267. bool JointsTrajectoryComponent::CheckIfPositionReachedTolerance(const trajectory_msgs::msg::JointTrajectoryPoint trajectoryPoint)
  268. {
  269. const auto& goalJointNames = m_trajectoryGoal.trajectory.joint_names;
  270. for (int jointIndex = 0; jointIndex < m_trajectoryGoal.trajectory.joint_names.size(); jointIndex++)
  271. { // Check if each joint reached its target position
  272. const AZStd::string_view jointName(goalJointNames[jointIndex].c_str());
  273. AZ::Outcome<float, AZStd::string> currentJointPosition;
  274. JointsManipulationRequestBus::EventResult(
  275. currentJointPosition, GetEntityId(), &JointsManipulationRequests::GetJointPosition, jointName);
  276. if (!currentJointPosition.IsSuccess())
  277. { // If position cannot be obtained, report failure
  278. return false;
  279. }
  280. const float targetPos = trajectoryPoint.positions[jointIndex];
  281. if (!AZ::IsClose(currentJointPosition.GetValue(), targetPos, m_jointPositionTolerance))
  282. {
  283. return false;
  284. }
  285. }
  286. return true;
  287. }
  288. bool JointsTrajectoryComponent::CheckIfVelocityReachedTolerance()
  289. {
  290. const auto& goalJointNames = m_trajectoryGoal.trajectory.joint_names;
  291. for (int jointIndex = 0; jointIndex < m_trajectoryGoal.trajectory.joint_names.size(); jointIndex++)
  292. { // Check if each joint velocity is below the threshold
  293. const AZStd::string_view jointName(goalJointNames[jointIndex].c_str());
  294. AZ::Outcome<float, AZStd::string> currentJointVelocity;
  295. JointsManipulationRequestBus::EventResult(
  296. currentJointVelocity, GetEntityId(), &JointsManipulationRequests::GetJointVelocity, jointName);
  297. if (!currentJointVelocity.IsSuccess())
  298. { // If velocity cannot be obtained, report failure
  299. return false;
  300. }
  301. if (!AZ::IsClose(currentJointVelocity.GetValue(), 0.0f, m_jointVelocityTolerance))
  302. {
  303. return false;
  304. }
  305. }
  306. return true;
  307. }
  308. void JointsTrajectoryComponent::MoveToNextPoint(const trajectory_msgs::msg::JointTrajectoryPoint currentTrajectoryPoint)
  309. {
  310. for (int jointIndex = 0; jointIndex < m_trajectoryGoal.trajectory.joint_names.size(); jointIndex++)
  311. { // Order each joint to be moved
  312. AZStd::string jointName(m_trajectoryGoal.trajectory.joint_names[jointIndex].c_str());
  313. AZ_Assert(m_manipulationJoints.find(jointName) != m_manipulationJoints.end(), "Invalid trajectory executing");
  314. float targetPos = currentTrajectoryPoint.positions[jointIndex];
  315. AZ::Outcome<void, AZStd::string> result;
  316. JointsManipulationRequestBus::EventResult(
  317. result, GetEntityId(), &JointsManipulationRequests::MoveJointToPosition, jointName, targetPos);
  318. AZ_Warning("JointTrajectoryComponent", result, "Joint move cannot be realized: %s", result.GetError().c_str());
  319. }
  320. }
  321. void JointsTrajectoryComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
  322. {
  323. if (m_manipulationJoints.empty())
  324. {
  325. GetManipulationJoints();
  326. return;
  327. }
  328. builtin_interfaces::msg::Time simTimestamp;
  329. ROS2::ROS2ClockRequestBus::BroadcastResult(simTimestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
  330. const float deltaSimulatedTime = ROS2::ROS2Conversions::GetTimeDifference(simTimestamp, m_lastTickTimestamp);
  331. m_lastTickTimestamp = simTimestamp;
  332. const uint64_t deltaTimeNs = deltaSimulatedTime * 1'000'000'000;
  333. FollowTrajectory(deltaTimeNs);
  334. UpdateFeedback();
  335. }
  336. } // namespace ROS2Controllers