JointsTrajectoryComponent.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "JointsTrajectoryComponent.h"
  9. #include <AzCore/Serialization/EditContext.h>
  10. #include <PhysX/ArticulationJointBus.h>
  11. #include <ROS2/Frame/ROS2FrameComponent.h>
  12. #include <ROS2/ROS2Bus.h>
  13. #include <ROS2/Utilities/ROS2Conversions.h>
  14. #include <ROS2/Utilities/ROS2Names.h>
  15. #include <ROS2Controllers/Manipulation/JointsManipulationRequests.h>
  16. namespace ROS2Controllers
  17. {
  18. JointsTrajectoryComponent::JointsTrajectoryComponent(const AZStd::string& followTrajectoryActionName)
  19. : m_followTrajectoryActionName(followTrajectoryActionName)
  20. {
  21. }
  22. void JointsTrajectoryComponent::Activate()
  23. {
  24. auto* ros2Frame = GetEntity()->FindComponent<ROS2::ROS2FrameComponent>();
  25. AZ_Assert(ros2Frame, "Missing Frame Component!");
  26. AZStd::string namespacedAction = ROS2::ROS2Names::GetNamespacedName(ros2Frame->GetNamespace(), m_followTrajectoryActionName);
  27. m_followTrajectoryServer = AZStd::make_unique<FollowJointTrajectoryActionServer>(namespacedAction, GetEntityId());
  28. AZ::TickBus::Handler::BusConnect();
  29. JointsTrajectoryRequestBus::Handler::BusConnect(GetEntityId());
  30. m_lastTickTimestamp = ROS2::ROS2Interface::Get()->GetROSTimestamp();
  31. }
  32. ManipulationJoints& JointsTrajectoryComponent::GetManipulationJoints()
  33. {
  34. if (m_manipulationJoints.empty())
  35. {
  36. JointsManipulationRequestBus::EventResult(m_manipulationJoints, GetEntityId(), &JointsManipulationRequests::GetJoints);
  37. }
  38. return m_manipulationJoints;
  39. }
  40. void JointsTrajectoryComponent::Deactivate()
  41. {
  42. JointsTrajectoryRequestBus::Handler::BusDisconnect();
  43. AZ::TickBus::Handler::BusDisconnect();
  44. m_followTrajectoryServer.reset();
  45. }
  46. void JointsTrajectoryComponent::Reflect(AZ::ReflectContext* context)
  47. {
  48. if (AZ::SerializeContext* serialize = azrtti_cast<AZ::SerializeContext*>(context))
  49. {
  50. serialize->Class<JointsTrajectoryComponent, AZ::Component>()->Version(0)->Field(
  51. "Action name", &JointsTrajectoryComponent::m_followTrajectoryActionName);
  52. if (AZ::EditContext* ec = serialize->GetEditContext())
  53. {
  54. ec->Class<JointsTrajectoryComponent>("JointsTrajectoryComponent", "Component to control a robotic arm using trajectories")
  55. ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
  56. ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC("Game"))
  57. ->Attribute(AZ::Edit::Attributes::Category, "ROS2")
  58. ->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/JointsTrajectoryComponent.svg")
  59. ->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/JointsTrajectoryComponent.svg")
  60. ->DataElement(
  61. AZ::Edit::UIHandlers::Default,
  62. &JointsTrajectoryComponent::m_followTrajectoryActionName,
  63. "Action Name",
  64. "Name the follow trajectory action server to accept movement commands");
  65. }
  66. }
  67. }
  68. void JointsTrajectoryComponent::GetRequiredServices(AZ::ComponentDescriptor::DependencyArrayType& required)
  69. {
  70. required.push_back(AZ_CRC_CE("ROS2Frame"));
  71. required.push_back(AZ_CRC_CE("JointsManipulationService"));
  72. }
  73. void JointsTrajectoryComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  74. {
  75. provided.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
  76. }
  77. void JointsTrajectoryComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
  78. {
  79. incompatible.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
  80. }
  81. AZ::Outcome<void, JointsTrajectoryComponent::TrajectoryResult> JointsTrajectoryComponent::StartTrajectoryGoal(
  82. TrajectoryGoalPtr trajectoryGoal)
  83. {
  84. if (m_goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
  85. {
  86. auto result = JointsTrajectoryComponent::TrajectoryResult();
  87. result.error_code = JointsTrajectoryComponent::TrajectoryResult::INVALID_GOAL;
  88. result.error_string = "Another trajectory goal is executing. Wait for completion or cancel it";
  89. return AZ::Failure(result);
  90. }
  91. auto validationResult = ValidateGoal(trajectoryGoal);
  92. if (!validationResult)
  93. {
  94. return validationResult;
  95. }
  96. m_trajectoryGoal = *trajectoryGoal;
  97. m_trajectoryExecutionStartTime = rclcpp::Time(ROS2::ROS2Interface::Get()->GetROSTimestamp());
  98. m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Executing;
  99. return AZ::Success();
  100. }
  101. AZ::Outcome<void, JointsTrajectoryComponent::TrajectoryResult> JointsTrajectoryComponent::ValidateGoal(TrajectoryGoalPtr trajectoryGoal)
  102. {
  103. // Check joint names validity
  104. for (const auto& jointName : trajectoryGoal->trajectory.joint_names)
  105. {
  106. AZStd::string azJointName(jointName.c_str());
  107. if (m_manipulationJoints.find(azJointName) == m_manipulationJoints.end())
  108. {
  109. AZ_Printf("JointsTrajectoryComponent", "Trajectory goal is invalid: no joint %s in manipulator", azJointName.c_str());
  110. auto result = JointsTrajectoryComponent::TrajectoryResult();
  111. result.error_code = JointsTrajectoryComponent::TrajectoryResult::INVALID_JOINTS;
  112. result.error_string = std::string(
  113. AZStd::string::format("Trajectory goal is invalid: no joint %s in manipulator", azJointName.c_str()).c_str());
  114. return AZ::Failure(result);
  115. }
  116. }
  117. return AZ::Success();
  118. }
  119. void JointsTrajectoryComponent::UpdateFeedback()
  120. {
  121. auto goalStatus = GetGoalStatus();
  122. if (goalStatus != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
  123. {
  124. return;
  125. }
  126. auto feedback = std::make_shared<control_msgs::action::FollowJointTrajectory::Feedback>();
  127. trajectory_msgs::msg::JointTrajectoryPoint desiredPoint = m_trajectoryGoal.trajectory.points.front();
  128. trajectory_msgs::msg::JointTrajectoryPoint actualPoint;
  129. size_t jointCount = m_trajectoryGoal.trajectory.joint_names.size();
  130. for (size_t jointIndex = 0; jointIndex < jointCount; jointIndex++)
  131. {
  132. AZStd::string jointName(m_trajectoryGoal.trajectory.joint_names[jointIndex].c_str());
  133. std::string jointNameStdString(jointName.c_str());
  134. feedback->joint_names.push_back(jointNameStdString);
  135. float currentJointPosition;
  136. float currentJointVelocity;
  137. auto& jointInfo = m_manipulationJoints[jointName];
  138. PhysX::ArticulationJointRequestBus::Event(
  139. jointInfo.m_entityComponentIdPair.GetEntityId(),
  140. [&](PhysX::ArticulationJointRequests* articulationJointRequests)
  141. {
  142. currentJointPosition = articulationJointRequests->GetJointPosition(jointInfo.m_axis);
  143. currentJointVelocity = articulationJointRequests->GetJointVelocity(jointInfo.m_axis);
  144. });
  145. actualPoint.positions.push_back(static_cast<double>(currentJointPosition));
  146. actualPoint.velocities.push_back(static_cast<double>(currentJointVelocity));
  147. // Acceleration should also be filled in somehow, or removed from the trajectory altogether.
  148. }
  149. trajectory_msgs::msg::JointTrajectoryPoint currentError;
  150. for (size_t jointIndex = 0; jointIndex < jointCount; jointIndex++)
  151. {
  152. currentError.positions.push_back(actualPoint.positions[jointIndex] - desiredPoint.positions[jointIndex]);
  153. currentError.velocities.push_back(actualPoint.velocities[jointIndex] - desiredPoint.velocities[jointIndex]);
  154. }
  155. feedback->desired = desiredPoint;
  156. feedback->actual = actualPoint;
  157. feedback->error = currentError;
  158. m_followTrajectoryServer->PublishFeedback(feedback);
  159. }
  160. AZ::Outcome<void, AZStd::string> JointsTrajectoryComponent::CancelTrajectoryGoal()
  161. {
  162. m_trajectoryGoal.trajectory.points.clear();
  163. m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled;
  164. return AZ::Success();
  165. }
  166. JointsTrajectoryRequests::TrajectoryActionStatus JointsTrajectoryComponent::GetGoalStatus()
  167. {
  168. return m_goalStatus;
  169. }
  170. void JointsTrajectoryComponent::FollowTrajectory(const uint64_t deltaTimeNs)
  171. {
  172. auto goalStatus = GetGoalStatus();
  173. if (goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled)
  174. {
  175. JointsManipulationRequestBus::Event(GetEntityId(), &JointsManipulationRequests::Stop);
  176. auto result = std::make_shared<FollowJointTrajectoryActionServer::FollowJointTrajectory::Result>();
  177. result->error_string = "User Cancelled";
  178. result->error_code = FollowJointTrajectoryActionServer::FollowJointTrajectory::Result::SUCCESSFUL;
  179. m_followTrajectoryServer->CancelGoal(result);
  180. return;
  181. }
  182. if (goalStatus != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
  183. {
  184. return;
  185. }
  186. if (m_trajectoryGoal.trajectory.points.size() == 0)
  187. { // The manipulator has reached the goal.
  188. AZ_TracePrintf("JointsManipulationComponent", "Goal Concluded: all points reached\n");
  189. auto successResult = std::make_shared<control_msgs::action::FollowJointTrajectory::Result>(); //!< Empty defaults to success.
  190. m_followTrajectoryServer->GoalSuccess(successResult);
  191. m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Succeeded;
  192. return;
  193. }
  194. auto desiredGoal = m_trajectoryGoal.trajectory.points.front();
  195. rclcpp::Duration targetGoalTime = rclcpp::Duration(desiredGoal.time_from_start); //!< Requested arrival time for trajectory point.
  196. rclcpp::Time timeNow = rclcpp::Time(ROS2::ROS2Interface::Get()->GetROSTimestamp()); //!< Current simulation time.
  197. rclcpp::Duration threshold = rclcpp::Duration::from_nanoseconds(1e7);
  198. if (m_trajectoryExecutionStartTime + targetGoalTime <= timeNow + threshold)
  199. { // Jump to the next point if current simulation time is ahead of timeFromStart
  200. m_trajectoryGoal.trajectory.points.erase(m_trajectoryGoal.trajectory.points.begin());
  201. FollowTrajectory(deltaTimeNs);
  202. return;
  203. }
  204. MoveToNextPoint(desiredGoal);
  205. }
  206. void JointsTrajectoryComponent::MoveToNextPoint(const trajectory_msgs::msg::JointTrajectoryPoint currentTrajectoryPoint)
  207. {
  208. for (int jointIndex = 0; jointIndex < m_trajectoryGoal.trajectory.joint_names.size(); jointIndex++)
  209. { // Order each joint to be moved
  210. AZStd::string jointName(m_trajectoryGoal.trajectory.joint_names[jointIndex].c_str());
  211. AZ_Assert(m_manipulationJoints.find(jointName) != m_manipulationJoints.end(), "Invalid trajectory executing");
  212. float targetPos = currentTrajectoryPoint.positions[jointIndex];
  213. AZ::Outcome<void, AZStd::string> result;
  214. JointsManipulationRequestBus::EventResult(
  215. result, GetEntityId(), &JointsManipulationRequests::MoveJointToPosition, jointName, targetPos);
  216. AZ_Warning("JointTrajectoryComponent", result, "Joint move cannot be realized: %s", result.GetError().c_str());
  217. }
  218. }
  219. void JointsTrajectoryComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
  220. {
  221. if (m_manipulationJoints.empty())
  222. {
  223. GetManipulationJoints();
  224. return;
  225. }
  226. const auto simTimestamp = ROS2::ROS2Interface::Get()->GetROSTimestamp();
  227. const float deltaSimulatedTime = ROS2::ROS2Conversions::GetTimeDifference(simTimestamp, m_lastTickTimestamp);
  228. m_lastTickTimestamp = simTimestamp;
  229. const uint64_t deltaTimeNs = deltaSimulatedTime * 1'000'000'000;
  230. FollowTrajectory(deltaTimeNs);
  231. UpdateFeedback();
  232. }
  233. } // namespace ROS2Controllers