/*
 * Copyright (c) Contributors to the Open 3D Engine Project.
 * For complete copyright and license terms please see the LICENSE at the root of this distribution.
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 */

#include "JointsTrajectoryComponent.h"
// NOTE(review): in this copy of the file all text that sat between angle brackets is missing: the paths of the
// includes below and the template arguments of calls such as FindComponent, make_unique, make_shared,
// azrtti_cast, Class and AZ::Outcome. They cannot be reconstructed reliably from this file alone, so they are
// reproduced verbatim here — restore them from version control.
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace ROS2Controllers
{
    //! Creates the component; the action name is used (namespaced by the ROS2 frame) when the server is created in Activate.
    JointsTrajectoryComponent::JointsTrajectoryComponent(const AZStd::string& followTrajectoryActionName)
        : m_followTrajectoryActionName(followTrajectoryActionName)
    {
    }

    //! Creates the FollowJointTrajectory action server and connects to the tick and trajectory request buses.
    void JointsTrajectoryComponent::Activate()
    {
        auto* ros2Frame = GetEntity()->FindComponent();
        AZ_Assert(ros2Frame, "Missing Frame Component!");

        // Resolve the configured action name inside the frame's namespace.
        AZStd::string namespacedAction;
        ROS2::ROS2NamesRequestBus::BroadcastResult(
            namespacedAction,
            &ROS2::ROS2NamesRequestBus::Events::GetNamespacedName,
            ros2Frame->GetNamespace(),
            m_followTrajectoryActionName);
        m_followTrajectoryServer = AZStd::make_unique(namespacedAction, GetEntityId());

        AZ::TickBus::Handler::BusConnect();
        JointsTrajectoryRequestBus::Handler::BusConnect(GetEntityId());

        // Seed the timestamp OnTick uses to compute the simulated time delta.
        ROS2::ROS2ClockRequestBus::BroadcastResult(m_lastTickTimestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
    }

    //! Returns the cached manipulation joints, querying JointsManipulationRequestBus while the cache is still empty.
    ManipulationJoints& JointsTrajectoryComponent::GetManipulationJoints()
    {
        if (m_manipulationJoints.empty())
        {
            JointsManipulationRequestBus::EventResult(m_manipulationJoints, GetEntityId(), &JointsManipulationRequests::GetJoints);
        }
        return m_manipulationJoints;
    }

    //! Disconnects from the buses and destroys the action server.
    void JointsTrajectoryComponent::Deactivate()
    {
        JointsTrajectoryRequestBus::Handler::BusDisconnect();
        AZ::TickBus::Handler::BusDisconnect();
        m_followTrajectoryServer.reset();
    }

    //! Registers serialized fields and editor properties. Serialized field names must stay stable for saved data.
    void JointsTrajectoryComponent::Reflect(AZ::ReflectContext* context)
    {
        if (AZ::SerializeContext* serialize = azrtti_cast(context))
        {
            serialize->Class()
                ->Version(1)
                ->Field("Action name", &JointsTrajectoryComponent::m_followTrajectoryActionName)
                ->Field("Check for position errors", &JointsTrajectoryComponent::m_checkForPositionErrors)
                ->Field("Joint goal tolerance", &JointsTrajectoryComponent::m_jointPositionTolerance)
                ->Field("Check for velocity", &JointsTrajectoryComponent::m_checkForVelocity)
                ->Field("Joint velocity tolerance", &JointsTrajectoryComponent::m_jointVelocityTolerance);

            if (AZ::EditContext* ec = serialize->GetEditContext())
            {
                ec->Class("JointsTrajectoryComponent", "Component to control a robotic arm using trajectories")
                    ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
                    ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC_CE("Game"))
                    ->Attribute(AZ::Edit::Attributes::Category, "ROS2")
                    ->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/JointsTrajectoryComponent.svg")
                    ->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/JointsTrajectoryComponent.svg")
                    ->DataElement(
                        AZ::Edit::UIHandlers::Default,
                        &JointsTrajectoryComponent::m_followTrajectoryActionName,
                        "Action Name",
                        "Name of the follow trajectory action server to accept movement commands")
                    ->DataElement(
                        AZ::Edit::UIHandlers::Default,
                        &JointsTrajectoryComponent::m_checkForPositionErrors,
                        "Check for Position Errors",
                        "If true, check if joints reached the goal position before reporting success")
                    ->Attribute(AZ::Edit::Attributes::ChangeNotify, AZ::Edit::PropertyRefreshLevels::AttributesAndValues)
                    ->DataElement(
                        AZ::Edit::UIHandlers::Default,
                        &JointsTrajectoryComponent::m_jointPositionTolerance,
                        "Joint Position Tolerance",
                        "The threshold for joint position errors to report the goal as reached (one value for all joints, units depend on "
                        "joint type)")
                    ->Attribute(AZ::Edit::Attributes::Visibility, &JointsTrajectoryComponent::ShouldCheckForPositionErrors)
                    ->DataElement(
                        AZ::Edit::UIHandlers::Default,
                        &JointsTrajectoryComponent::m_checkForVelocity,
                        "Check for Velocity",
                        "If true, check if joints velocity is below threshold before reporting success")
                    ->Attribute(AZ::Edit::Attributes::ChangeNotify, AZ::Edit::PropertyRefreshLevels::AttributesAndValues)
                    ->DataElement(
                        AZ::Edit::UIHandlers::Default,
                        &JointsTrajectoryComponent::m_jointVelocityTolerance,
                        "Joint Velocity Tolerance",
                        "The threshold for joint velocities under which to report the goal as reached (one value for all joints, units "
                        "depend on joint type)")
                    ->Attribute(AZ::Edit::Attributes::Visibility, &JointsTrajectoryComponent::ShouldCheckForVelocity);
            }
        }
    }

    void JointsTrajectoryComponent::GetRequiredServices(AZ::ComponentDescriptor::DependencyArrayType& required)
    {
        required.push_back(AZ_CRC_CE("ROS2Frame"));
        required.push_back(AZ_CRC_CE("JointsManipulationService"));
    }

    void JointsTrajectoryComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
    {
        provided.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
    }

    //! Declaring the provided service as incompatible allows at most one such component per entity.
    void JointsTrajectoryComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
    {
        incompatible.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
    }

    //! Editor visibility callback for the position tolerance field.
    bool JointsTrajectoryComponent::ShouldCheckForPositionErrors()
    {
        return m_checkForPositionErrors;
    }

    //! Editor visibility callback for the velocity tolerance field.
    bool JointsTrajectoryComponent::ShouldCheckForVelocity()
    {
        return m_checkForVelocity;
    }

    //! Validates the goal and, when valid and no other goal is executing, stores it and starts execution.
    //! @param trajectoryGoal goal received from the action server.
    //! @return success, or a TrajectoryResult describing why the goal was rejected.
    AZ::Outcome JointsTrajectoryComponent::StartTrajectoryGoal(TrajectoryGoalPtr trajectoryGoal)
    {
        if (m_goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
        {
            auto result = JointsTrajectoryComponent::TrajectoryResult();
            result.error_code = JointsTrajectoryComponent::TrajectoryResult::INVALID_GOAL;
            result.error_string = "Another trajectory goal is executing. Wait for completion or cancel it";
            return AZ::Failure(result);
        }

        auto validationResult = ValidateGoal(trajectoryGoal);
        if (!validationResult)
        {
            return validationResult;
        }

        m_trajectoryGoal = *trajectoryGoal;
        // Point times (time_from_start) are measured relative to this timestamp in FollowTrajectory.
        ROS2::ROS2ClockRequestBus::BroadcastResult(m_trajectoryExecutionStartTime, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
        m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Executing;
        return AZ::Success();
    }

    //! Checks that a goal can be executed safely by this component.
    //! Rejects a null goal, joint names unknown to the manipulator (INVALID_JOINTS), and points that do not carry a
    //! position for every joint or carry a non-empty but too short velocity list (INVALID_GOAL). Execution and
    //! feedback index each point by joint index, so these checks prevent out-of-bounds reads later on.
    AZ::Outcome JointsTrajectoryComponent::ValidateGoal(TrajectoryGoalPtr trajectoryGoal)
    {
        // Logs the reason and wraps it into a failed outcome with the given error code.
        const auto reject = [](auto errorCode, const AZStd::string& reason)
        {
            AZ_Printf("JointsTrajectoryComponent", "%s", reason.c_str());
            auto result = JointsTrajectoryComponent::TrajectoryResult();
            result.error_code = errorCode;
            result.error_string = std::string(reason.c_str());
            return AZ::Failure(result);
        };

        if (!trajectoryGoal)
        {
            return reject(JointsTrajectoryComponent::TrajectoryResult::INVALID_GOAL, "Trajectory goal is invalid: goal is empty");
        }

        const auto& trajectory = trajectoryGoal->trajectory;

        // Use the lazy getter: the joint cache is otherwise only filled on the first tick, and a goal arriving
        // before that would be rejected against an empty cache.
        const ManipulationJoints& manipulationJoints = GetManipulationJoints();
        for (const auto& jointName : trajectory.joint_names)
        {
            const AZStd::string azJointName(jointName.c_str());
            if (manipulationJoints.find(azJointName) == manipulationJoints.end())
            {
                return reject(
                    JointsTrajectoryComponent::TrajectoryResult::INVALID_JOINTS,
                    AZStd::string::format("Trajectory goal is invalid: no joint %s in manipulator", azJointName.c_str()));
            }
        }

        const size_t jointCount = trajectory.joint_names.size();
        for (size_t pointIndex = 0; pointIndex < trajectory.points.size(); ++pointIndex)
        {
            const auto& point = trajectory.points[pointIndex];
            if (point.positions.size() < jointCount)
            {
                return reject(
                    JointsTrajectoryComponent::TrajectoryResult::INVALID_GOAL,
                    AZStd::string::format(
                        "Trajectory goal is invalid: point %zu has %zu positions for %zu joints",
                        pointIndex,
                        point.positions.size(),
                        jointCount));
            }
            // An empty velocity list is accepted (velocity targets omitted); a partial one is not.
            if (!point.velocities.empty() && point.velocities.size() < jointCount)
            {
                return reject(
                    JointsTrajectoryComponent::TrajectoryResult::INVALID_GOAL,
                    AZStd::string::format(
                        "Trajectory goal is invalid: point %zu has %zu velocities for %zu joints",
                        pointIndex,
                        point.velocities.size(),
                        jointCount));
            }
        }

        return AZ::Success();
    }

    //! Publishes action feedback for the executing goal: the current target point, the measured joint state and
    //! their difference. Does nothing when no goal is executing or no point is left.
    void JointsTrajectoryComponent::UpdateFeedback()
    {
        if (GetGoalStatus() != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
        {
            return;
        }

        const auto& points = m_trajectoryGoal.trajectory.points;
        if (points.empty())
        {
            // No target point to compare against; front() would be undefined behavior.
            return;
        }

        auto feedback = std::make_shared();
        const trajectory_msgs::msg::JointTrajectoryPoint& desiredPoint = points.front();
        const auto& jointNames = m_trajectoryGoal.trajectory.joint_names;
        const size_t jointCount = jointNames.size();

        trajectory_msgs::msg::JointTrajectoryPoint actualPoint;
        feedback->joint_names.reserve(jointCount);
        actualPoint.positions.reserve(jointCount);
        actualPoint.velocities.reserve(jointCount);
        for (const auto& jointNameStd : jointNames)
        {
            feedback->joint_names.push_back(jointNameStd);

            // Zero-initialized so that a joint nobody answers for reports 0 instead of an indeterminate value.
            float currentJointPosition = 0.0f;
            float currentJointVelocity = 0.0f;
            // find() rather than operator[] so an unknown name never inserts a default joint into the cache.
            const auto jointIt = m_manipulationJoints.find(AZStd::string(jointNameStd.c_str()));
            if (jointIt != m_manipulationJoints.end())
            {
                const auto& jointInfo = jointIt->second;
                PhysX::ArticulationJointRequestBus::Event(
                    jointInfo.m_entityComponentIdPair.GetEntityId(),
                    [&](PhysX::ArticulationJointRequests* articulationJointRequests)
                    {
                        currentJointPosition = articulationJointRequests->GetJointPosition(jointInfo.m_axis);
                        currentJointVelocity = articulationJointRequests->GetJointVelocity(jointInfo.m_axis);
                    });
            }
            actualPoint.positions.push_back(currentJointPosition);
            actualPoint.velocities.push_back(currentJointVelocity);
            // Acceleration should also be filled in somehow, or removed from the trajectory altogether.
        }

        trajectory_msgs::msg::JointTrajectoryPoint currentError;
        if (desiredPoint.positions.size() >= jointCount)
        {
            currentError.positions.reserve(jointCount);
            for (size_t jointIndex = 0; jointIndex < jointCount; ++jointIndex)
            {
                currentError.positions.push_back(actualPoint.positions[jointIndex] - desiredPoint.positions[jointIndex]);
            }
        }
        // Velocity targets may be omitted from a point (ValidateGoal accepts an empty list); only report a
        // velocity error when targets exist, otherwise indexing desiredPoint.velocities would read out of bounds.
        if (desiredPoint.velocities.size() >= jointCount)
        {
            currentError.velocities.reserve(jointCount);
            for (size_t jointIndex = 0; jointIndex < jointCount; ++jointIndex)
            {
                currentError.velocities.push_back(actualPoint.velocities[jointIndex] - desiredPoint.velocities[jointIndex]);
            }
        }

        feedback->desired = desiredPoint;
        feedback->actual = std::move(actualPoint);
        feedback->error = std::move(currentError);
        m_followTrajectoryServer->PublishFeedback(feedback);
    }

    //! Drops the remaining points and marks the goal as cancelled; FollowTrajectory reports the cancellation.
    AZ::Outcome JointsTrajectoryComponent::CancelTrajectoryGoal()
    {
        m_trajectoryGoal.trajectory.points.clear();
        m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled;
        return AZ::Success();
    }

    JointsTrajectoryRequests::TrajectoryActionStatus JointsTrajectoryComponent::GetGoalStatus()
    {
        return m_goalStatus;
    }

    //! Advances execution of the current goal by one tick.
    //! Points whose scheduled time (execution start + time_from_start, within a 10 ms threshold) has passed are
    //! skipped; the last point additionally waits for the optional position/velocity tolerance checks. When all
    //! points are consumed the goal is reported as succeeded; otherwise joints are commanded toward the first
    //! pending point.
    void JointsTrajectoryComponent::FollowTrajectory([[maybe_unused]] const uint64_t deltaTimeNs)
    {
        const auto goalStatus = GetGoalStatus();
        if (goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled)
        {
            // NOTE(review): m_goalStatus is not changed here, so this branch (Stop + CancelGoal) runs on every tick
            // until a new goal is started — confirm that is intended.
            JointsManipulationRequestBus::Event(GetEntityId(), &JointsManipulationRequests::Stop);
            auto result = std::make_shared();
            result->error_string = "User Cancelled";
            result->error_code = FollowJointTrajectoryActionServer::FollowJointTrajectory::Result::SUCCESSFUL;
            m_followTrajectoryServer->CancelGoal(result);
            return;
        }

        if (goalStatus != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
        {
            return;
        }

        builtin_interfaces::msg::Time timestamp;
        ROS2::ROS2ClockRequestBus::BroadcastResult(timestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
        const rclcpp::Time timeNow = rclcpp::Time(timestamp); //!< Current simulation time.
        const rclcpp::Duration threshold = rclcpp::Duration::from_nanoseconds(10'000'000); //!< 10 ms slack.

        // Iterative rather than recursive so that skipping many overdue points cannot grow the call stack.
        auto& points = m_trajectoryGoal.trajectory.points;
        while (!points.empty())
        {
            const auto& desiredGoal = points.front();
            //!< Requested arrival time for trajectory point.
            const rclcpp::Duration targetGoalTime = rclcpp::Duration(desiredGoal.time_from_start);

            // Jump to the next point if current simulation time is ahead of timeFromStart
            bool canJumpToNextPoint = m_trajectoryExecutionStartTime + targetGoalTime <= timeNow + threshold;

            // But if it's the last point, wait until the manipulator reaches it
            const bool lastTrajectoryPoint = points.size() == 1;
            if (lastTrajectoryPoint)
            {
                if (m_checkForPositionErrors)
                {
                    canJumpToNextPoint &= CheckIfPositionReachedTolerance(desiredGoal);
                }
                if (m_checkForVelocity)
                {
                    canJumpToNextPoint &= CheckIfVelocityReachedTolerance();
                }
            }

            if (!canJumpToNextPoint)
            {
                MoveToNextPoint(desiredGoal);
                return;
            }

            // desiredGoal refers to the erased element; it is not used past this point.
            points.erase(points.begin());
        }

        // The manipulator has reached the goal.
        AZ_TracePrintf("JointsTrajectoryComponent", "Goal Concluded: all points reached\n");
        auto successResult = std::make_shared(); //!< Empty defaults to success.
        m_followTrajectoryServer->GoalSuccess(successResult);
        m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Succeeded;
    }

    //! Returns true when every goal joint is within m_jointPositionTolerance of its target in trajectoryPoint.
    //! Returns false if any position cannot be obtained or the point lacks a target for some joint.
    bool JointsTrajectoryComponent::CheckIfPositionReachedTolerance(const trajectory_msgs::msg::JointTrajectoryPoint trajectoryPoint)
    {
        const auto& goalJointNames = m_trajectoryGoal.trajectory.joint_names;
        if (trajectoryPoint.positions.size() < goalJointNames.size())
        {
            // Missing targets: the goal cannot be considered reached (and indexing would be out of bounds).
            return false;
        }

        for (size_t jointIndex = 0; jointIndex < goalJointNames.size(); ++jointIndex)
        {
            // Check if each joint reached its target position
            const AZStd::string_view jointName(goalJointNames[jointIndex].c_str());
            AZ::Outcome currentJointPosition;
            JointsManipulationRequestBus::EventResult(
                currentJointPosition, GetEntityId(), &JointsManipulationRequests::GetJointPosition, jointName);
            if (!currentJointPosition.IsSuccess())
            {
                // If position cannot be obtained, report failure
                return false;
            }
            const float targetPos = trajectoryPoint.positions[jointIndex];
            if (!AZ::IsClose(currentJointPosition.GetValue(), targetPos, m_jointPositionTolerance))
            {
                return false;
            }
        }
        return true;
    }

    //! Returns true when every goal joint's velocity is within m_jointVelocityTolerance of zero.
    //! Returns false if any velocity cannot be obtained.
    bool JointsTrajectoryComponent::CheckIfVelocityReachedTolerance()
    {
        const auto& goalJointNames = m_trajectoryGoal.trajectory.joint_names;
        for (size_t jointIndex = 0; jointIndex < goalJointNames.size(); ++jointIndex)
        {
            // Check if each joint velocity is below the threshold
            const AZStd::string_view jointName(goalJointNames[jointIndex].c_str());
            AZ::Outcome currentJointVelocity;
            JointsManipulationRequestBus::EventResult(
                currentJointVelocity, GetEntityId(), &JointsManipulationRequests::GetJointVelocity, jointName);
            if (!currentJointVelocity.IsSuccess())
            {
                // If velocity cannot be obtained, report failure
                return false;
            }
            if (!AZ::IsClose(currentJointVelocity.GetValue(), 0.0f, m_jointVelocityTolerance))
            {
                return false;
            }
        }
        return true;
    }

    //! Commands every goal joint toward its target position in currentTrajectoryPoint, warning on failures.
    void JointsTrajectoryComponent::MoveToNextPoint(const trajectory_msgs::msg::JointTrajectoryPoint currentTrajectoryPoint)
    {
        const auto& goalJointNames = m_trajectoryGoal.trajectory.joint_names;
        for (size_t jointIndex = 0; jointIndex < goalJointNames.size(); ++jointIndex)
        {
            // Order each joint to be moved
            AZStd::string jointName(goalJointNames[jointIndex].c_str());
            AZ_Assert(m_manipulationJoints.find(jointName) != m_manipulationJoints.end(), "Invalid trajectory executing");

            float targetPos = currentTrajectoryPoint.positions[jointIndex];
            AZ::Outcome result;
            JointsManipulationRequestBus::EventResult(
                result, GetEntityId(), &JointsManipulationRequests::MoveJointToPosition, jointName, targetPos);
            AZ_Warning("JointsTrajectoryComponent", result, "Joint move cannot be realized: %s", result.GetError().c_str());
        }
    }

    //! Per-frame update: fetches the joint cache on the first tick(s), then advances the trajectory and publishes feedback.
    void JointsTrajectoryComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
    {
        if (m_manipulationJoints.empty())
        {
            GetManipulationJoints();
            return;
        }

        builtin_interfaces::msg::Time simTimestamp;
        ROS2::ROS2ClockRequestBus::BroadcastResult(simTimestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
        const float deltaSimulatedTime = ROS2::ROS2Conversions::GetTimeDifference(simTimestamp, m_lastTickTimestamp);
        m_lastTickTimestamp = simTimestamp;

        // Converting a negative (or NaN) float to an unsigned integer is undefined behavior, e.g. after a clock
        // reset; clamp such deltas to zero.
        const uint64_t deltaTimeNs = deltaSimulatedTime > 0.0f ? static_cast<uint64_t>(deltaSimulatedTime * 1'000'000'000) : 0;
        FollowTrajectory(deltaTimeNs);
        UpdateFeedback();
    }
} // namespace ROS2Controllers