| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386 |
- /*
- * Copyright (c) Contributors to the Open 3D Engine Project.
- * For complete copyright and license terms please see the LICENSE at the root of this distribution.
- *
- * SPDX-License-Identifier: Apache-2.0 OR MIT
- *
- */
- #include "JointsTrajectoryComponent.h"
- #include <AzCore/Outcome/Outcome.h>
- #include <AzCore/Serialization/EditContext.h>
- #include <PhysX/ArticulationJointBus.h>
- #include <ROS2/Clock/ROS2ClockRequestBus.h>
- #include <ROS2/Frame/ROS2FrameComponent.h>
- #include <ROS2/ROS2Bus.h>
- #include <ROS2/ROS2NamesBus.h>
- #include <ROS2/Utilities/ROS2Conversions.h>
- #include <ROS2Controllers/Manipulation/JointsManipulationRequests.h>
- namespace ROS2Controllers
- {
- JointsTrajectoryComponent::JointsTrajectoryComponent(const AZStd::string& followTrajectoryActionName)
- : m_followTrajectoryActionName(followTrajectoryActionName)
- {
- }
- void JointsTrajectoryComponent::Activate()
- {
- auto* ros2Frame = GetEntity()->FindComponent<ROS2::ROS2FrameComponent>();
- AZ_Assert(ros2Frame, "Missing Frame Component!");
- AZStd::string namespacedAction;
- ROS2::ROS2NamesRequestBus::BroadcastResult(
- namespacedAction,
- &ROS2::ROS2NamesRequestBus::Events::GetNamespacedName,
- ros2Frame->GetNamespace(),
- m_followTrajectoryActionName);
- m_followTrajectoryServer = AZStd::make_unique<FollowJointTrajectoryActionServer>(namespacedAction, GetEntityId());
- AZ::TickBus::Handler::BusConnect();
- JointsTrajectoryRequestBus::Handler::BusConnect(GetEntityId());
- ROS2::ROS2ClockRequestBus::BroadcastResult(m_lastTickTimestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
- }
- ManipulationJoints& JointsTrajectoryComponent::GetManipulationJoints()
- {
- if (m_manipulationJoints.empty())
- {
- JointsManipulationRequestBus::EventResult(m_manipulationJoints, GetEntityId(), &JointsManipulationRequests::GetJoints);
- }
- return m_manipulationJoints;
- }
- void JointsTrajectoryComponent::Deactivate()
- {
- JointsTrajectoryRequestBus::Handler::BusDisconnect();
- AZ::TickBus::Handler::BusDisconnect();
- m_followTrajectoryServer.reset();
- }
- void JointsTrajectoryComponent::Reflect(AZ::ReflectContext* context)
- {
- if (AZ::SerializeContext* serialize = azrtti_cast<AZ::SerializeContext*>(context))
- {
- serialize->Class<JointsTrajectoryComponent, AZ::Component>()->Version(1)
- ->Field("Action name", &JointsTrajectoryComponent::m_followTrajectoryActionName)
- ->Field("Check for position errors", &JointsTrajectoryComponent::m_checkForPositionErrors)
- ->Field("Joint goal tolerance", &JointsTrajectoryComponent::m_jointPositionTolerance)
- ->Field("Check for velocity", &JointsTrajectoryComponent::m_checkForVelocity)
- ->Field("Joint velocity tolerance", &JointsTrajectoryComponent::m_jointVelocityTolerance);
- if (AZ::EditContext* ec = serialize->GetEditContext())
- {
- ec->Class<JointsTrajectoryComponent>("JointsTrajectoryComponent", "Component to control a robotic arm using trajectories")
- ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
- ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC("Game"))
- ->Attribute(AZ::Edit::Attributes::Category, "ROS2")
- ->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/JointsTrajectoryComponent.svg")
- ->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/JointsTrajectoryComponent.svg")
- ->DataElement(
- AZ::Edit::UIHandlers::Default,
- &JointsTrajectoryComponent::m_followTrajectoryActionName,
- "Action Name",
- "Name the follow trajectory action server to accept movement commands")
- ->DataElement(
- AZ::Edit::UIHandlers::Default,
- &JointsTrajectoryComponent::m_checkForPositionErrors,
- "Check for Position Errors",
- "If true, check if joints reached the goal position before reporting success")
- ->Attribute(AZ::Edit::Attributes::ChangeNotify, AZ::Edit::PropertyRefreshLevels::AttributesAndValues)
- ->DataElement(
- AZ::Edit::UIHandlers::Default,
- &JointsTrajectoryComponent::m_jointPositionTolerance,
- "Joint Position Tolerance",
- "The threshold for joint position errors to report the goal as reached (one value for all joints, units depend on joint type)")
- ->Attribute(AZ::Edit::Attributes::Visibility, &JointsTrajectoryComponent::ShouldCheckForPositionErrors)
- ->DataElement(
- AZ::Edit::UIHandlers::Default,
- &JointsTrajectoryComponent::m_checkForVelocity,
- "Check for Velocity",
- "If true, check if joints velocity is below threshold before reporting success")
- ->Attribute(AZ::Edit::Attributes::ChangeNotify, AZ::Edit::PropertyRefreshLevels::AttributesAndValues)
- ->DataElement(
- AZ::Edit::UIHandlers::Default,
- &JointsTrajectoryComponent::m_jointVelocityTolerance,
- "Joint Velocity Tolerance",
- "The threshold for joint velocities under which to report the goal as reached (one value for all joints, units depend on joint type)"
- )
- ->Attribute(AZ::Edit::Attributes::Visibility, &JointsTrajectoryComponent::ShouldCheckForVelocity);
- }
- }
- }
- void JointsTrajectoryComponent::GetRequiredServices(AZ::ComponentDescriptor::DependencyArrayType& required)
- {
- required.push_back(AZ_CRC_CE("ROS2Frame"));
- required.push_back(AZ_CRC_CE("JointsManipulationService"));
- }
- void JointsTrajectoryComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
- {
- provided.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
- }
- void JointsTrajectoryComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
- {
- incompatible.push_back(AZ_CRC_CE("ManipulatorJointTrajectoryService"));
- }
- bool JointsTrajectoryComponent::ShouldCheckForPositionErrors()
- {
- return m_checkForPositionErrors;
- }
- bool JointsTrajectoryComponent::ShouldCheckForVelocity()
- {
- return m_checkForVelocity;
- }
- AZ::Outcome<void, JointsTrajectoryComponent::TrajectoryResult> JointsTrajectoryComponent::StartTrajectoryGoal(
- TrajectoryGoalPtr trajectoryGoal)
- {
- if (m_goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
- {
- auto result = JointsTrajectoryComponent::TrajectoryResult();
- result.error_code = JointsTrajectoryComponent::TrajectoryResult::INVALID_GOAL;
- result.error_string = "Another trajectory goal is executing. Wait for completion or cancel it";
- return AZ::Failure(result);
- }
- auto validationResult = ValidateGoal(trajectoryGoal);
- if (!validationResult)
- {
- return validationResult;
- }
- m_trajectoryGoal = *trajectoryGoal;
- ROS2::ROS2ClockRequestBus::BroadcastResult(m_trajectoryExecutionStartTime, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
- m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Executing;
- return AZ::Success();
- }
- AZ::Outcome<void, JointsTrajectoryComponent::TrajectoryResult> JointsTrajectoryComponent::ValidateGoal(TrajectoryGoalPtr trajectoryGoal)
- {
- // Check joint names validity
- for (const auto& jointName : trajectoryGoal->trajectory.joint_names)
- {
- AZStd::string azJointName(jointName.c_str());
- if (m_manipulationJoints.find(azJointName) == m_manipulationJoints.end())
- {
- AZ_Printf("JointsTrajectoryComponent", "Trajectory goal is invalid: no joint %s in manipulator", azJointName.c_str());
- auto result = JointsTrajectoryComponent::TrajectoryResult();
- result.error_code = JointsTrajectoryComponent::TrajectoryResult::INVALID_JOINTS;
- result.error_string = std::string(
- AZStd::string::format("Trajectory goal is invalid: no joint %s in manipulator", azJointName.c_str()).c_str());
- return AZ::Failure(result);
- }
- }
- return AZ::Success();
- }
- void JointsTrajectoryComponent::UpdateFeedback()
- {
- auto goalStatus = GetGoalStatus();
- if (goalStatus != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
- {
- return;
- }
- auto feedback = std::make_shared<control_msgs::action::FollowJointTrajectory::Feedback>();
- trajectory_msgs::msg::JointTrajectoryPoint desiredPoint = m_trajectoryGoal.trajectory.points.front();
- trajectory_msgs::msg::JointTrajectoryPoint actualPoint;
- size_t jointCount = m_trajectoryGoal.trajectory.joint_names.size();
- for (size_t jointIndex = 0; jointIndex < jointCount; jointIndex++)
- {
- AZStd::string jointName(m_trajectoryGoal.trajectory.joint_names[jointIndex].c_str());
- std::string jointNameStdString(jointName.c_str());
- feedback->joint_names.push_back(jointNameStdString);
- float currentJointPosition;
- float currentJointVelocity;
- auto& jointInfo = m_manipulationJoints[jointName];
- PhysX::ArticulationJointRequestBus::Event(
- jointInfo.m_entityComponentIdPair.GetEntityId(),
- [&](PhysX::ArticulationJointRequests* articulationJointRequests)
- {
- currentJointPosition = articulationJointRequests->GetJointPosition(jointInfo.m_axis);
- currentJointVelocity = articulationJointRequests->GetJointVelocity(jointInfo.m_axis);
- });
- actualPoint.positions.push_back(static_cast<double>(currentJointPosition));
- actualPoint.velocities.push_back(static_cast<double>(currentJointVelocity));
- // Acceleration should also be filled in somehow, or removed from the trajectory altogether.
- }
- trajectory_msgs::msg::JointTrajectoryPoint currentError;
- for (size_t jointIndex = 0; jointIndex < jointCount; jointIndex++)
- {
- currentError.positions.push_back(actualPoint.positions[jointIndex] - desiredPoint.positions[jointIndex]);
- currentError.velocities.push_back(actualPoint.velocities[jointIndex] - desiredPoint.velocities[jointIndex]);
- }
- feedback->desired = desiredPoint;
- feedback->actual = actualPoint;
- feedback->error = currentError;
- m_followTrajectoryServer->PublishFeedback(feedback);
- }
- AZ::Outcome<void, AZStd::string> JointsTrajectoryComponent::CancelTrajectoryGoal()
- {
- m_trajectoryGoal.trajectory.points.clear();
- m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled;
- return AZ::Success();
- }
- JointsTrajectoryRequests::TrajectoryActionStatus JointsTrajectoryComponent::GetGoalStatus()
- {
- return m_goalStatus;
- }
- void JointsTrajectoryComponent::FollowTrajectory(const uint64_t deltaTimeNs)
- {
- auto goalStatus = GetGoalStatus();
- if (goalStatus == JointsTrajectoryRequests::TrajectoryActionStatus::Cancelled)
- {
- JointsManipulationRequestBus::Event(GetEntityId(), &JointsManipulationRequests::Stop);
- auto result = std::make_shared<FollowJointTrajectoryActionServer::FollowJointTrajectory::Result>();
- result->error_string = "User Cancelled";
- result->error_code = FollowJointTrajectoryActionServer::FollowJointTrajectory::Result::SUCCESSFUL;
- m_followTrajectoryServer->CancelGoal(result);
- return;
- }
- if (goalStatus != JointsTrajectoryRequests::TrajectoryActionStatus::Executing)
- {
- return;
- }
- if (m_trajectoryGoal.trajectory.points.size() == 0)
- { // The manipulator has reached the goal.
- AZ_TracePrintf("JointsManipulationComponent", "Goal Concluded: all points reached\n");
- auto successResult = std::make_shared<control_msgs::action::FollowJointTrajectory::Result>(); //!< Empty defaults to success.
- m_followTrajectoryServer->GoalSuccess(successResult);
- m_goalStatus = JointsTrajectoryRequests::TrajectoryActionStatus::Succeeded;
- return;
- }
- auto desiredGoal = m_trajectoryGoal.trajectory.points.front();
- rclcpp::Duration targetGoalTime = rclcpp::Duration(desiredGoal.time_from_start); //!< Requested arrival time for trajectory point.
- builtin_interfaces::msg::Time timestamp;
- ROS2::ROS2ClockRequestBus::BroadcastResult(timestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
- const rclcpp::Time timeNow = rclcpp::Time(timestamp); //!< Current simulation time.
- rclcpp::Duration threshold = rclcpp::Duration::from_nanoseconds(1e7);
- // Jump to the next point if current simulation time is ahead of timeFromStart
- bool canJumpToNextPoint = m_trajectoryExecutionStartTime + targetGoalTime <= timeNow + threshold;
- // But if it's the last point, wait until the manipulator reaches it
- bool lastTrajectoryPoint = m_trajectoryGoal.trajectory.points.size() == 1;
- if (lastTrajectoryPoint)
- {
- if (m_checkForPositionErrors)
- {
- canJumpToNextPoint &= CheckIfPositionReachedTolerance(desiredGoal);
- }
- if (m_checkForVelocity)
- {
- canJumpToNextPoint &= CheckIfVelocityReachedTolerance();
- }
- }
- if (canJumpToNextPoint)
- { // Jump to the next point in the trajectory
- m_trajectoryGoal.trajectory.points.erase(m_trajectoryGoal.trajectory.points.begin());
- FollowTrajectory(deltaTimeNs);
- return;
- }
- MoveToNextPoint(desiredGoal);
- }
- bool JointsTrajectoryComponent::CheckIfPositionReachedTolerance(const trajectory_msgs::msg::JointTrajectoryPoint trajectoryPoint)
- {
- const auto& goalJointNames = m_trajectoryGoal.trajectory.joint_names;
- for (int jointIndex = 0; jointIndex < m_trajectoryGoal.trajectory.joint_names.size(); jointIndex++)
- { // Check if each joint reached its target position
- const AZStd::string_view jointName(goalJointNames[jointIndex].c_str());
- AZ::Outcome<float, AZStd::string> currentJointPosition;
- JointsManipulationRequestBus::EventResult(
- currentJointPosition, GetEntityId(), &JointsManipulationRequests::GetJointPosition, jointName);
-
- if (!currentJointPosition.IsSuccess())
- { // If position cannot be obtained, report failure
- return false;
- }
- const float targetPos = trajectoryPoint.positions[jointIndex];
- if (!AZ::IsClose(currentJointPosition.GetValue(), targetPos, m_jointPositionTolerance))
- {
- return false;
- }
- }
- return true;
- }
- bool JointsTrajectoryComponent::CheckIfVelocityReachedTolerance()
- {
- const auto& goalJointNames = m_trajectoryGoal.trajectory.joint_names;
- for (int jointIndex = 0; jointIndex < m_trajectoryGoal.trajectory.joint_names.size(); jointIndex++)
- { // Check if each joint velocity is below the threshold
- const AZStd::string_view jointName(goalJointNames[jointIndex].c_str());
- AZ::Outcome<float, AZStd::string> currentJointVelocity;
- JointsManipulationRequestBus::EventResult(
- currentJointVelocity, GetEntityId(), &JointsManipulationRequests::GetJointVelocity, jointName);
- if (!currentJointVelocity.IsSuccess())
- { // If velocity cannot be obtained, report failure
- return false;
- }
- if (!AZ::IsClose(currentJointVelocity.GetValue(), 0.0f, m_jointVelocityTolerance))
- {
- return false;
- }
- }
- return true;
- }
- void JointsTrajectoryComponent::MoveToNextPoint(const trajectory_msgs::msg::JointTrajectoryPoint currentTrajectoryPoint)
- {
- for (int jointIndex = 0; jointIndex < m_trajectoryGoal.trajectory.joint_names.size(); jointIndex++)
- { // Order each joint to be moved
- AZStd::string jointName(m_trajectoryGoal.trajectory.joint_names[jointIndex].c_str());
- AZ_Assert(m_manipulationJoints.find(jointName) != m_manipulationJoints.end(), "Invalid trajectory executing");
- float targetPos = currentTrajectoryPoint.positions[jointIndex];
- AZ::Outcome<void, AZStd::string> result;
- JointsManipulationRequestBus::EventResult(
- result, GetEntityId(), &JointsManipulationRequests::MoveJointToPosition, jointName, targetPos);
- AZ_Warning("JointTrajectoryComponent", result, "Joint move cannot be realized: %s", result.GetError().c_str());
- }
- }
- void JointsTrajectoryComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
- {
- if (m_manipulationJoints.empty())
- {
- GetManipulationJoints();
- return;
- }
- builtin_interfaces::msg::Time simTimestamp;
- ROS2::ROS2ClockRequestBus::BroadcastResult(simTimestamp, &ROS2::ROS2ClockRequestBus::Events::GetROSTimestamp);
- const float deltaSimulatedTime = ROS2::ROS2Conversions::GetTimeDifference(simTimestamp, m_lastTickTimestamp);
- m_lastTickTimestamp = simTimestamp;
- const uint64_t deltaTimeNs = deltaSimulatedTime * 1'000'000'000;
- FollowTrajectory(deltaTimeNs);
- UpdateFeedback();
- }
- } // namespace ROS2Controllers
|