JointsManipulationComponent.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "JointsManipulationComponent.h"
  9. #include "Controllers/JointsArticulationControllerComponent.h"
  10. #include "Controllers/JointsPIDControllerComponent.h"
  11. #include "JointStatePublisher.h"
  12. #include "ManipulationUtils.h"
  13. #include <AzCore/Component/ComponentApplicationBus.h>
  14. #include <AzCore/Component/TransformBus.h>
  15. #include <AzCore/Debug/Trace.h>
  16. #include <AzCore/Serialization/EditContext.h>
  17. #include <ROS2/Frame/ROS2FrameComponent.h>
  18. #include <ROS2/Utilities/ROS2Conversions.h>
  19. #include <ROS2/Utilities/ROS2Names.h>
  20. #include <ROS2Controllers/Manipulation/Controllers/JointsPositionControllerRequests.h>
  21. #include <Source/ArticulationLinkComponent.h>
  22. #include <Source/HingeJointComponent.h>
  23. #include <Source/PrismaticJointComponent.h>
  24. #include <Utilities/ArticulationsUtilities.h>
  25. namespace ROS2Controllers
  26. {
  27. namespace Internal
  28. {
  29. void Add1DOFJointInfo(const AZ::EntityComponentIdPair& idPair, const AZStd::string& jointName, ManipulationJoints& joints)
  30. {
  31. if (joints.find(jointName) != joints.end())
  32. {
  33. AZ_Assert(false, "Joint names in hierarchy need to be unique (%s is not)!", jointName.c_str());
  34. return;
  35. }
  36. AZ_Printf("JointsManipulationComponent", "Adding joint info for hinge joint %s\n", jointName.c_str());
  37. JointInfo jointInfo;
  38. jointInfo.m_isArticulation = false;
  39. jointInfo.m_axis = static_cast<PhysX::ArticulationJointAxis>(0);
  40. jointInfo.m_entityComponentIdPair = idPair;
  41. joints[jointName] = jointInfo;
  42. }
  43. void AddArticulationJointInfo(const AZ::EntityComponentIdPair& idPair, const AZStd::string& jointName, ManipulationJoints& joints)
  44. {
  45. PhysX::ArticulationJointAxis freeAxis;
  46. bool hasFreeAxis = Utils::TryGetFreeArticulationAxis(idPair.GetEntityId(), freeAxis);
  47. if (!hasFreeAxis)
  48. { // Do not add a joint since it is a fixed one
  49. AZ_Printf("JointsManipulationComponent", "Articulation joint %s is fixed, skipping\n", jointName.c_str());
  50. return;
  51. }
  52. if (joints.find(jointName) != joints.end())
  53. {
  54. AZ_Assert(false, "Joint names in hierarchy need to be unique (%s is not)!", jointName.c_str());
  55. return;
  56. }
  57. AZ_Printf("JointsManipulationComponent", "Adding joint info for articulation link %s\n", jointName.c_str());
  58. JointInfo jointInfo;
  59. jointInfo.m_isArticulation = true;
  60. jointInfo.m_axis = freeAxis;
  61. jointInfo.m_entityComponentIdPair = idPair;
  62. joints[jointName] = jointInfo;
  63. }
  64. ManipulationJoints GetAllEntityHierarchyJoints(const AZ::EntityId& entityId)
  65. { // Look for either Articulation Links or Hinge joints in entity hierarchy and collect them into a map.
  66. // Determine kind of joints through presence of appropriate controller
  67. bool supportsArticulation = false;
  68. bool supportsClassicJoints = false;
  69. JointsPositionControllerRequestBus::EventResult(
  70. supportsArticulation, entityId, &JointsPositionControllerRequests::SupportsArticulation);
  71. JointsPositionControllerRequestBus::EventResult(
  72. supportsClassicJoints, entityId, &JointsPositionControllerRequests::SupportsClassicJoints);
  73. ManipulationJoints manipulationJoints;
  74. if (!supportsArticulation && !supportsClassicJoints)
  75. {
  76. AZ_Warning("JointsManipulationComponent", false, "No suitable Position Controller Component in entity!");
  77. return manipulationJoints;
  78. }
  79. if (supportsArticulation && supportsClassicJoints)
  80. {
  81. AZ_Warning("JointsManipulationComponent", false, "Cannot support both classic joint and articulations in one hierarchy");
  82. return manipulationJoints;
  83. }
  84. // Get all descendants and iterate over joints
  85. AZStd::vector<AZ::EntityId> descendants;
  86. AZ::TransformBus::EventResult(descendants, entityId, &AZ::TransformInterface::GetEntityAndAllDescendants);
  87. AZ_Warning("JointsManipulationComponent", descendants.size() > 0, "Entity %s has no descendants!", entityId.ToString().c_str());
  88. for (const AZ::EntityId& descendantID : descendants)
  89. {
  90. AZ::Entity* entity = nullptr;
  91. AZ::ComponentApplicationBus::BroadcastResult(entity, &AZ::ComponentApplicationRequests::FindEntity, descendantID);
  92. AZ_Assert(entity, "Unknown entity %s", descendantID.ToString().c_str());
  93. // If there is a Frame Component, take joint name stored in it.
  94. auto* frameComponent = entity->FindComponent<ROS2::ROS2FrameComponent>();
  95. if (!frameComponent)
  96. { // Frame Component is required for joints.
  97. continue;
  98. }
  99. const AZStd::string jointName(frameComponent->GetJointName().GetCStr());
  100. auto* hingeComponent = entity->FindComponent<PhysX::HingeJointComponent>();
  101. auto* prismaticComponent = entity->FindComponent<PhysX::PrismaticJointComponent>();
  102. auto* articulationComponent = entity->FindComponent<PhysX::ArticulationLinkComponent>();
  103. [[maybe_unused]] bool classicJoint = hingeComponent || prismaticComponent;
  104. AZ_Warning(
  105. "JointsManipulationComponent",
  106. (classicJoint && supportsClassicJoints) || !classicJoint,
  107. "Found classic joints but the controller does not support them!");
  108. AZ_Warning(
  109. "JointsManipulationComponent",
  110. (articulationComponent && supportsArticulation) || !articulationComponent,
  111. "Found articulations but the controller does not support them!");
  112. // See if there is a Hinge Joint in the entity, add it to map.
  113. if (supportsClassicJoints && hingeComponent)
  114. {
  115. auto idPair = AZ::EntityComponentIdPair(hingeComponent->GetEntityId(), hingeComponent->GetId());
  116. Internal::Add1DOFJointInfo(idPair, jointName, manipulationJoints);
  117. }
  118. // See if there is a Prismatic Joint in the entity, add it to map.
  119. if (supportsClassicJoints && prismaticComponent)
  120. {
  121. auto idPair = AZ::EntityComponentIdPair(prismaticComponent->GetEntityId(), prismaticComponent->GetId());
  122. Internal::Add1DOFJointInfo(idPair, jointName, manipulationJoints);
  123. }
  124. // See if there is an Articulation Link in the entity, add it to map.
  125. if (supportsArticulation && articulationComponent)
  126. {
  127. auto idPair = AZ::EntityComponentIdPair(articulationComponent->GetEntityId(), articulationComponent->GetId());
  128. Internal::AddArticulationJointInfo(idPair, jointName, manipulationJoints);
  129. }
  130. }
  131. return manipulationJoints;
  132. }
  133. void SetInitialPositions(ManipulationJoints& manipulationJoints, const AZStd::unordered_map<AZStd::string, float>& initialPositions)
  134. {
  135. // Set the initial / resting position to move to and keep.
  136. for (const auto& [jointName, jointInfo] : manipulationJoints)
  137. {
  138. if (initialPositions.contains(jointName))
  139. {
  140. manipulationJoints[jointName].m_restPosition = initialPositions.at(jointName);
  141. }
  142. else
  143. {
  144. AZ_Warning("JointsManipulationComponent", false, "No set initial position for joint %s", jointName.c_str());
  145. }
  146. }
  147. }
  148. } // namespace Internal
  149. JointsManipulationComponent::JointsManipulationComponent()
  150. {
  151. }
  152. JointsManipulationComponent::JointsManipulationComponent(
  153. const ROS2::PublisherConfiguration& publisherConfiguration,
  154. const AZStd::vector<AZStd::pair<AZStd::string, float>>& initialPositions)
  155. : m_jointStatePublisherConfiguration(publisherConfiguration)
  156. , m_initialPositions(initialPositions)
  157. {
  158. }
  159. void JointsManipulationComponent::Activate()
  160. {
  161. auto* ros2Frame = GetEntity()->FindComponent<ROS2::ROS2FrameComponent>();
  162. JointStatePublisherContext publisherContext;
  163. publisherContext.m_publisherNamespace = ros2Frame->GetNamespace();
  164. publisherContext.m_frameId = ros2Frame->GetFrameID();
  165. publisherContext.m_entityId = GetEntityId();
  166. m_jointStatePublisher = AZStd::make_unique<JointStatePublisher>(m_jointStatePublisherConfiguration, publisherContext);
  167. m_lastTickTimestamp = ROS2::ROS2Interface::Get()->GetROSTimestamp();
  168. AZ::TickBus::Handler::BusConnect();
  169. JointsManipulationRequestBus::Handler::BusConnect(GetEntityId());
  170. }
  171. void JointsManipulationComponent::Deactivate()
  172. {
  173. JointsManipulationRequestBus::Handler::BusDisconnect();
  174. AZ::TickBus::Handler::BusDisconnect();
  175. }
  176. ManipulationJoints JointsManipulationComponent::GetJoints()
  177. {
  178. return m_manipulationJoints;
  179. }
  180. AZ::Outcome<JointPosition, AZStd::string> JointsManipulationComponent::GetJointPosition(const JointInfo& jointInfo)
  181. {
  182. float position{ 0.f };
  183. if (jointInfo.m_isArticulation)
  184. {
  185. PhysX::ArticulationJointRequestBus::EventResult(
  186. position,
  187. jointInfo.m_entityComponentIdPair.GetEntityId(),
  188. &PhysX::ArticulationJointRequests::GetJointPosition,
  189. jointInfo.m_axis);
  190. }
  191. else
  192. {
  193. PhysX::JointRequestBus::EventResult(position, jointInfo.m_entityComponentIdPair, &PhysX::JointRequests::GetPosition);
  194. }
  195. return AZ::Success(position);
  196. }
  197. AZ::Outcome<JointPosition, AZStd::string> JointsManipulationComponent::GetJointPosition(const AZStd::string& jointName)
  198. {
  199. if (!m_manipulationJoints.contains(jointName))
  200. {
  201. return AZ::Failure(AZStd::string::format("Joint %s does not exist", jointName.c_str()));
  202. }
  203. auto jointInfo = m_manipulationJoints.at(jointName);
  204. return GetJointPosition(jointInfo);
  205. }
  206. AZ::Outcome<JointVelocity, AZStd::string> JointsManipulationComponent::GetJointVelocity(const JointInfo& jointInfo)
  207. {
  208. float velocity{ 0.f };
  209. if (jointInfo.m_isArticulation)
  210. {
  211. PhysX::ArticulationJointRequestBus::EventResult(
  212. velocity,
  213. jointInfo.m_entityComponentIdPair.GetEntityId(),
  214. &PhysX::ArticulationJointRequests::GetJointVelocity,
  215. jointInfo.m_axis);
  216. }
  217. else
  218. {
  219. PhysX::JointRequestBus::EventResult(velocity, jointInfo.m_entityComponentIdPair, &PhysX::JointRequests::GetVelocity);
  220. }
  221. return AZ::Success(velocity);
  222. }
  223. AZ::Outcome<JointVelocity, AZStd::string> JointsManipulationComponent::GetJointVelocity(const AZStd::string& jointName)
  224. {
  225. if (!m_manipulationJoints.contains(jointName))
  226. {
  227. return AZ::Failure(AZStd::string::format("Joint %s does not exist", jointName.c_str()));
  228. }
  229. auto jointInfo = m_manipulationJoints.at(jointName);
  230. return GetJointVelocity(jointInfo);
  231. }
  232. JointsManipulationRequests::JointsPositionsMap JointsManipulationComponent::GetAllJointsPositions()
  233. {
  234. JointsManipulationRequests::JointsPositionsMap positions;
  235. for (const auto& [jointName, jointInfo] : m_manipulationJoints)
  236. {
  237. positions[jointName] = GetJointPosition(jointInfo).GetValue();
  238. }
  239. return positions;
  240. }
  241. JointsManipulationRequests::JointsVelocitiesMap JointsManipulationComponent::GetAllJointsVelocities()
  242. {
  243. JointsManipulationRequests::JointsVelocitiesMap velocities;
  244. for (const auto& [jointName, jointInfo] : m_manipulationJoints)
  245. {
  246. velocities[jointName] = GetJointVelocity(jointInfo).GetValue();
  247. }
  248. return velocities;
  249. }
  250. AZ::Outcome<JointEffort, AZStd::string> JointsManipulationComponent::GetJointEffort(const JointInfo& jointInfo)
  251. {
  252. auto jointStateData = Utils::GetJointState(jointInfo);
  253. return AZ::Success(jointStateData.effort);
  254. }
  255. AZ::Outcome<JointEffort, AZStd::string> JointsManipulationComponent::GetJointEffort(const AZStd::string& jointName)
  256. {
  257. if (!m_manipulationJoints.contains(jointName))
  258. {
  259. return AZ::Failure(AZStd::string::format("Joint %s does not exist", jointName.c_str()));
  260. }
  261. auto jointInfo = m_manipulationJoints.at(jointName);
  262. return GetJointEffort(jointInfo);
  263. }
  264. JointsManipulationRequests::JointsEffortsMap JointsManipulationComponent::GetAllJointsEfforts()
  265. {
  266. JointsManipulationRequests::JointsEffortsMap efforts;
  267. for (const auto& [jointName, jointInfo] : m_manipulationJoints)
  268. {
  269. efforts[jointName] = GetJointEffort(jointInfo).GetValue();
  270. }
  271. return efforts;
  272. }
  273. AZ::Outcome<void, AZStd::string> JointsManipulationComponent::SetMaxJointEffort(const AZStd::string& jointName, JointEffort maxEffort)
  274. {
  275. if (!m_manipulationJoints.contains(jointName))
  276. {
  277. return AZ::Failure(AZStd::string::format("Joint %s does not exist", jointName.c_str()));
  278. }
  279. auto jointInfo = m_manipulationJoints.at(jointName);
  280. if (jointInfo.m_isArticulation)
  281. {
  282. PhysX::ArticulationJointRequestBus::Event(
  283. jointInfo.m_entityComponentIdPair.GetEntityId(),
  284. &PhysX::ArticulationJointRequests::SetMaxForce,
  285. jointInfo.m_axis,
  286. maxEffort);
  287. }
  288. return AZ::Success();
  289. }
  290. AZ::Outcome<void, AZStd::string> JointsManipulationComponent::MoveJointToPosition(
  291. const AZStd::string& jointName, JointPosition position)
  292. {
  293. if (!m_manipulationJoints.contains(jointName))
  294. {
  295. return AZ::Failure(AZStd::string::format("Joint %s does not exist", jointName.c_str()));
  296. }
  297. m_manipulationJoints[jointName].m_restPosition = position;
  298. return AZ::Success();
  299. }
  300. AZ::Outcome<void, AZStd::string> JointsManipulationComponent::MoveJointsToPositions(
  301. const JointsManipulationRequests::JointsPositionsMap& positions)
  302. {
  303. for (const auto& [jointName, position] : positions)
  304. {
  305. auto outcome = MoveJointToPosition(jointName, position);
  306. if (!outcome)
  307. {
  308. return outcome;
  309. }
  310. }
  311. return AZ::Success();
  312. }
  313. void JointsManipulationComponent::GetRequiredServices(AZ::ComponentDescriptor::DependencyArrayType& required)
  314. {
  315. required.push_back(AZ_CRC_CE("ROS2Frame"));
  316. required.push_back(AZ_CRC_CE("JointsControllerService"));
  317. }
  318. void JointsManipulationComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
  319. {
  320. provided.push_back(AZ_CRC_CE("JointsManipulationService"));
  321. }
  322. void JointsManipulationComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
  323. {
  324. incompatible.push_back(AZ_CRC_CE("JointsManipulationService"));
  325. }
  326. void JointsManipulationComponent::Reflect(AZ::ReflectContext* context)
  327. {
  328. if (AZ::SerializeContext* serialize = azrtti_cast<AZ::SerializeContext*>(context))
  329. {
  330. serialize->Class<JointsManipulationComponent, AZ::Component>()
  331. ->Version(2)
  332. ->Field("JointStatesPublisherConfiguration", &JointsManipulationComponent::m_jointStatePublisherConfiguration)
  333. ->Field("OrderedInitialJointPositions", &JointsManipulationComponent::m_initialPositions);
  334. }
  335. }
  336. void JointsManipulationComponent::MoveToSetPositions(float deltaTime)
  337. {
  338. for (const auto& [jointName, jointInfo] : m_manipulationJoints)
  339. {
  340. float currentPosition = GetJointPosition(jointName).GetValue();
  341. float desiredPosition = jointInfo.m_restPosition;
  342. AZ::Outcome<void, AZStd::string> positionControlOutcome;
  343. JointsPositionControllerRequestBus::EventResult(
  344. positionControlOutcome,
  345. GetEntityId(),
  346. &JointsPositionControllerRequests::PositionControl,
  347. jointName,
  348. jointInfo,
  349. currentPosition,
  350. desiredPosition,
  351. deltaTime);
  352. AZ_Warning(
  353. "JointsManipulationComponent",
  354. positionControlOutcome,
  355. "Position control failed for joint %s (%s): %s",
  356. jointName.c_str(),
  357. jointInfo.m_entityComponentIdPair.GetEntityId().ToString().c_str(),
  358. positionControlOutcome.GetError().c_str());
  359. }
  360. }
  361. void JointsManipulationComponent::Stop()
  362. {
  363. for (auto& [jointName, jointInfo] : m_manipulationJoints)
  364. { // Set all target joint positions to their current positions. There is no need to check if the outcome is successful, because
  365. // jointName is always valid.
  366. jointInfo.m_restPosition = GetJointPosition(jointName).GetValue();
  367. }
  368. }
  369. AZStd::string JointsManipulationComponent::GetManipulatorNamespace() const
  370. {
  371. auto* frameComponent = GetEntity()->FindComponent<ROS2::ROS2FrameComponent>();
  372. AZ_Assert(frameComponent, "ROS2FrameComponent is required for joints.");
  373. return frameComponent->GetNamespace();
  374. }
  375. void JointsManipulationComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
  376. {
  377. if (m_manipulationJoints.empty())
  378. {
  379. const AZStd::string manipulatorNamespace = GetManipulatorNamespace();
  380. AZStd::unordered_map<AZStd::string, JointPosition> initialPositionNamespaced;
  381. AZStd::transform(
  382. m_initialPositions.begin(),
  383. m_initialPositions.end(),
  384. AZStd::inserter(initialPositionNamespaced, initialPositionNamespaced.end()),
  385. [&manipulatorNamespace](const auto& pair)
  386. {
  387. return AZStd::make_pair(ROS2::ROS2Names::GetNamespacedName(manipulatorNamespace, pair.first), pair.second);
  388. });
  389. m_manipulationJoints = Internal::GetAllEntityHierarchyJoints(GetEntityId());
  390. Internal::SetInitialPositions(m_manipulationJoints, initialPositionNamespaced);
  391. if (m_manipulationJoints.empty())
  392. {
  393. AZ_Warning("JointsManipulationComponent", false, "No manipulation joints to handle!");
  394. AZ::TickBus::Handler::BusDisconnect();
  395. return;
  396. }
  397. m_jointStatePublisher->InitializePublisher();
  398. }
  399. auto simTimestamp = ROS2::ROS2Interface::Get()->GetROSTimestamp();
  400. float deltaSimTime = ROS2::ROS2Conversions::GetTimeDifference(m_lastTickTimestamp, simTimestamp);
  401. MoveToSetPositions(deltaSimTime);
  402. m_lastTickTimestamp = simTimestamp;
  403. }
  404. } // namespace ROS2Controllers