Преглед изворни кода

[ROS 2] Lidar segmentation (#754)

Signed-off-by: Aleksander Kamiński <[email protected]>
Signed-off-by: Krzysztof Rymski <[email protected]>
Signed-off-by: Michał Pełka <[email protected]>
Co-authored-by: Krzysztof Rymski <[email protected]>
Co-authored-by: Michał Pełka <[email protected]>
Aleksander Kamiński пре 9 месеци
родитељ
комит
0a74c155ae
24 измењених фајлова са 726 додато и 54 уклоњено
  1. 1 1
      Gems/ROS2/Code/CMakeLists.txt
  2. 60 0
      Gems/ROS2/Code/Include/ROS2/Lidar/ClassSegmentationBus.h
  3. 11 8
      Gems/ROS2/Code/Include/ROS2/Lidar/LidarRegistrarBus.h
  4. 21 0
      Gems/ROS2/Code/Include/ROS2/Lidar/RaycastResults.h
  5. 40 0
      Gems/ROS2/Code/Include/ROS2/Lidar/SegmentationClassConfiguration.h
  6. 22 0
      Gems/ROS2/Code/Include/ROS2/Lidar/SegmentationUtils.h
  7. 181 0
      Gems/ROS2/Code/Source/Lidar/ClassSegmentationConfigurationComponent.cpp
  8. 50 0
      Gems/ROS2/Code/Source/Lidar/ClassSegmentationConfigurationComponent.h
  9. 37 16
      Gems/ROS2/Code/Source/Lidar/LidarCore.cpp
  10. 5 0
      Gems/ROS2/Code/Source/Lidar/LidarCore.h
  11. 65 10
      Gems/ROS2/Code/Source/Lidar/LidarRaycaster.cpp
  12. 5 0
      Gems/ROS2/Code/Source/Lidar/LidarRaycaster.h
  13. 14 1
      Gems/ROS2/Code/Source/Lidar/LidarSensorConfiguration.cpp
  14. 6 5
      Gems/ROS2/Code/Source/Lidar/LidarSensorConfiguration.h
  15. 2 2
      Gems/ROS2/Code/Source/Lidar/LidarSystem.cpp
  16. 1 0
      Gems/ROS2/Code/Source/Lidar/PointCloudMessageBuilder.h
  17. 85 10
      Gems/ROS2/Code/Source/Lidar/ROS2LidarSensorComponent.cpp
  18. 2 0
      Gems/ROS2/Code/Source/Lidar/ROS2LidarSensorComponent.h
  19. 5 0
      Gems/ROS2/Code/Source/Lidar/RaycastResults.cpp
  20. 45 0
      Gems/ROS2/Code/Source/Lidar/SegmentationClassConfiguration.cpp
  21. 58 0
      Gems/ROS2/Code/Source/Lidar/SegmentationUtils.cpp
  22. 2 0
      Gems/ROS2/Code/Source/ROS2ModuleInterface.h
  23. 4 0
      Gems/ROS2/Code/ros2_files.cmake
  24. 4 1
      Gems/ROS2/Code/ros2_header_files.cmake

+ 1 - 1
Gems/ROS2/Code/CMakeLists.txt

@@ -69,7 +69,7 @@ ly_add_target(
             Gem::LmbrCentral.API
 )
 
-target_depends_on_ros2_packages(${gem_name}.Static rclcpp builtin_interfaces std_msgs sensor_msgs nav_msgs tf2_ros ackermann_msgs gazebo_msgs)
+target_depends_on_ros2_packages(${gem_name}.Static rclcpp builtin_interfaces std_msgs sensor_msgs nav_msgs tf2_ros ackermann_msgs gazebo_msgs vision_msgs)
 target_depends_on_ros2_package(${gem_name}.Static control_toolbox 2.2.0 REQUIRED)
 
 ly_add_target(

+ 60 - 0
Gems/ROS2/Code/Include/ROS2/Lidar/ClassSegmentationBus.h

@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <AzCore/Component/EntityId.h>
+#include <AzCore/EBus/EBus.h>
+#include <AzCore/Interface/Interface.h>
+#include <AzCore/Math/Color.h>
+#include <LmbrCentral/Scripting/TagComponentBus.h>
+#include <ROS2/Lidar/SegmentationClassConfiguration.h>
+
+namespace ROS2
+{
+    static constexpr uint8_t UnknownClassId = 0U;
+    static constexpr uint8_t TerrainClassId = 1U;
+    using SegmentationClassConfigList = AZStd::vector<SegmentationClassConfiguration>;
+
+    //! Interface class that allows for retrieval of segmentation class information.
+    class ClassSegmentationRequests
+    {
+    public:
+        AZ_RTTI(ClassSegmentationRequests, "{69b4109e-25ff-482f-b92e-f19cdf06bce2}");
+
+        //! Returns the color of segmentation class with the provided class ID.
+        //! If no segmentation class is found with provided class ID, returns AZ::Colors::White.
+        //! @param classId Class ID of the segmentation class.
+        //! @return Color of the class with provided ID.
+        virtual AZ::Color GetClassColor(uint8_t classId) const = 0;
+
+        //! If segmentation class exists that is associated with provided tag,
+        //! returns ID of this class. Otherwise, returns AZStd::nullopt;
+        //! @param tag Tag associated with the segmentation class.
+        //! @return ID of found class or AZStd::nullopt.
+        virtual AZStd::optional<uint8_t> GetClassIdForTag(LmbrCentral::Tag tag) const = 0;
+
+        //! Returns a reference to the segmentation config list.
+        virtual const SegmentationClassConfigList& GetClassConfigList() const = 0;
+
+    protected:
+        virtual ~ClassSegmentationRequests() = default;
+    };
+
+    class ClassSegmentationRequestBusTraits : public AZ::EBusTraits
+    {
+    public:
+        //////////////////////////////////////////////////////////////////////////
+        // EBusTraits overrides
+        static constexpr AZ::EBusHandlerPolicy HandlerPolicy = AZ::EBusHandlerPolicy::Single;
+        static constexpr AZ::EBusAddressPolicy AddressPolicy = AZ::EBusAddressPolicy::Single;
+        //////////////////////////////////////////////////////////////////////////
+    };
+
+    using ClassSegmentationRequestBus = AZ::EBus<ClassSegmentationRequests, ClassSegmentationRequestBusTraits>;
+    using ClassSegmentationInterface = AZ::Interface<ClassSegmentationRequests>;
+} // namespace ROS2

+ 11 - 8
Gems/ROS2/Code/Include/ROS2/Lidar/LidarRegistrarBus.h

@@ -16,14 +16,17 @@ namespace ROS2
     //! Enum bitwise flags used to describe LidarSystem's feature support.
     enum LidarSystemFeatures : uint16_t
     {
-        None                    = 0,
-        Noise                   = 1,
-        CollisionLayers         = 1 << 1,
-        EntityExclusion         = 1 << 2,
-        MaxRangePoints          = 1 << 3,
-        PointcloudPublishing    = 1 << 4,
-        Intensity               = 1 << 5,
-        All                     = 0b1111111111111111,
+        // clang-format off
+        None =                  0,
+        Noise =                 1,
+        CollisionLayers =       1 << 1,
+        EntityExclusion =       1 << 2,
+        MaxRangePoints =        1 << 3,
+        PointcloudPublishing =  1 << 4,
+        Intensity =             1 << 5,
+        Segmentation =          1 << 6,
+        All =                   (1 << 7) - 1, // All feature bits enabled.
+        // clang-format on
     };
 
     //! Structure used to hold LidarSystem's metadata.

+ 21 - 0
Gems/ROS2/Code/Include/ROS2/Lidar/RaycastResults.h

@@ -18,6 +18,7 @@ namespace ROS2
         Point = (1 << 0), //!< return 3D point coordinates
         Range = (1 << 1), //!< return array of distances
         Intensity = (1 << 2), //!< return intensity data
+        SegmentationData = (1 << 3), //!< return segmentation data
     };
 
     //! Bitwise operators for RaycastResultFlags
@@ -49,6 +50,18 @@ namespace ROS2
         using Type = float;
     };
 
+    struct SegmentationIds
+    {
+        int32_t m_entityId;
+        AZ::u8 m_classId;
+    };
+
+    template<>
+    struct ResultTraits<RaycastResultFlags::SegmentationData>
+    {
+        using Type = SegmentationIds;
+    };
+
     //! Class used for storing the results of a raycast.
     //! It guarantees a uniform length of all its fields.
     class RaycastResults
@@ -106,6 +119,7 @@ namespace ROS2
         FieldInternal<RaycastResultFlags::Point> m_points;
         FieldInternal<RaycastResultFlags::Range> m_ranges;
         FieldInternal<RaycastResultFlags::Intensity> m_intensities;
+        FieldInternal<RaycastResultFlags::SegmentationData> m_segmentationData;
     };
 
     template<RaycastResultFlags F>
@@ -157,6 +171,13 @@ namespace ROS2
         return m_intensities;
     }
 
+    template<>
+    inline const RaycastResults::FieldInternal<RaycastResultFlags::SegmentationData>& RaycastResults::GetField<
+        RaycastResultFlags::SegmentationData>() const
+    {
+        return m_segmentationData;
+    }
+
     template<RaycastResultFlags F>
     RaycastResults::FieldInternal<F>& RaycastResults::GetField()
     {

+ 40 - 0
Gems/ROS2/Code/Include/ROS2/Lidar/SegmentationClassConfiguration.h

@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <AzCore/Component/EntityId.h>
+#include <AzCore/RTTI/RTTI.h>
+#include <AzCore/Serialization/SerializeContext.h>
+#include <AzCore/std/string/string.h>
+
+namespace ROS2
+{
+    //! A structure capturing configuration of a segmentation class.
+    class SegmentationClassConfiguration
+    {
+    public:
+        AZ_TYPE_INFO(SegmentationClassConfiguration, "{e46e75f4-1e0e-48ca-a22f-43afc8f25133}");
+        static void Reflect(AZ::ReflectContext* context);
+
+        static const SegmentationClassConfiguration UnknownClass;
+        static const SegmentationClassConfiguration GroundClass;
+
+        SegmentationClassConfiguration() = default;
+
+        SegmentationClassConfiguration(const AZStd::string& className, const uint8_t classId, const AZ::Color& classColor)
+            : m_className(className)
+            , m_classId(classId)
+            , m_classColor(classColor)
+        {
+        }
+
+        AZStd::string m_className = "Default";
+        uint8_t m_classId = 0;
+        AZ::Color m_classColor = AZ::Color(1.0f, 1.0f, 1.0f, 1.0f);
+    };
+} // namespace ROS2

+ 22 - 0
Gems/ROS2/Code/Include/ROS2/Lidar/SegmentationUtils.h

@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <AzCore/Component/EntityId.h>
+
+namespace ROS2::SegmentationUtils
+{
+    //! Returns the segmentation class ID of the entity with provided ID.
+    //! Entity's class ID is fetched using the Tag component (@see Tag).
+    //! If this entity has a tag with a name that matches an existing
+    //! segmentation class (configured through the Class Segmentation component),
+    //! the ID of this class is returned. Otherwise, the Unknown Class ID is returned.
+    //! @param entityId ID of the entity for which a class ID is to be fetched.
+    //! @return Class ID of the entity.
+    [[nodiscard]] uint8_t FetchClassIdForEntity(AZ::EntityId entityId);
+} // namespace ROS2::SegmentationUtils

+ 181 - 0
Gems/ROS2/Code/Source/Lidar/ClassSegmentationConfigurationComponent.cpp

@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#include <AzCore/Serialization/EditContext.h>
+#include <AzCore/Serialization/SerializeContext.h>
+#include <Lidar/ClassSegmentationConfigurationComponent.h>
+
+namespace ROS2
+{
+    void ClassSegmentationConfigurationComponent::Reflect(AZ::ReflectContext* context)
+    {
+        SegmentationClassConfiguration::Reflect(context);
+        if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
+        {
+            serializeContext->Class<ClassSegmentationConfigurationComponent, AZ::Component>()->Version(0)->Field(
+                "Segmentation Classes", &ClassSegmentationConfigurationComponent::m_segmentationClasses);
+
+            if (auto* editContext = serializeContext->GetEditContext())
+            {
+                editContext->Class<ClassSegmentationConfigurationComponent>("Class Segmentation Configuration", "")
+                    ->ClassElement(AZ::Edit::ClassElements::EditorData, "")
+                    ->Attribute(AZ::Edit::Attributes::Category, "ROS2")
+                    ->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZStd::vector<AZ::Crc32>({ AZ_CRC_CE("Level") }))
+                    ->DataElement(
+                        AZ::Edit::UIHandlers::Default,
+                        &ClassSegmentationConfigurationComponent::m_segmentationClasses,
+                        "Segmentation classes",
+                        "Segmentation classes and their colors.")
+                    ->Attribute(AZ::Edit::Attributes::AutoExpand, true)
+                    ->Attribute(
+                        AZ::Edit::Attributes::ChangeValidate, &ClassSegmentationConfigurationComponent::ValidateSegmentationClasses)
+                    ->Attribute(
+                        AZ::Edit::Attributes::ChangeNotify, &ClassSegmentationConfigurationComponent::OnSegmentationClassesChanged);
+            }
+        }
+    }
+
+    void ClassSegmentationConfigurationComponent::GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided)
+    {
+        provided.push_back(AZ_CRC_CE("ClassSegmentationConfig"));
+    }
+
+    void ClassSegmentationConfigurationComponent::GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible)
+    {
+        incompatible.push_back(AZ_CRC_CE("ClassSegmentationConfig"));
+    }
+
+    AZ::Color ClassSegmentationConfigurationComponent::GetClassColor(uint8_t classId) const
+    {
+        auto it = m_classIdToColor.find(classId);
+        if (it == m_classIdToColor.end())
+        {
+            return AZ::Colors::White;
+        }
+
+        return it->second;
+    }
+
+    AZStd::optional<uint8_t> ClassSegmentationConfigurationComponent::GetClassIdForTag(LmbrCentral::Tag tag) const
+    {
+        auto it = m_tagToClassId.find(tag);
+        if (it == m_tagToClassId.end())
+        {
+            return AZStd::nullopt;
+        }
+
+        return it->second;
+    }
+
+    const SegmentationClassConfigList& ClassSegmentationConfigurationComponent::GetClassConfigList() const
+    {
+        return m_segmentationClasses;
+    }
+
+    void ClassSegmentationConfigurationComponent::Activate()
+    {
+        ConstructSegmentationClassMaps();
+
+        if (!ClassSegmentationInterface::Get())
+        {
+            ClassSegmentationInterface::Register(this);
+        }
+    }
+
+    void ClassSegmentationConfigurationComponent::Deactivate()
+    {
+        if (ClassSegmentationInterface::Get() == this)
+        {
+            ClassSegmentationInterface::Unregister(this);
+        }
+
+        m_classIdToColor.clear();
+        m_tagToClassId.clear();
+    }
+
+    AZ::Outcome<void, AZStd::string> ClassSegmentationConfigurationComponent::ValidateSegmentationClasses(
+        void* newValue, const AZ::TypeId& valueType) const
+    {
+        AZ_Assert(azrtti_typeid<SegmentationClassConfigList>() == valueType, "Unexpected value type");
+        if (azrtti_typeid<SegmentationClassConfigList>() != valueType)
+        {
+            return AZ::Failure(AZStd::string("Unexpectedly received an invalid type as segmentation classes!"));
+        }
+
+        const auto& segmentationClasses = *azrtti_cast<SegmentationClassConfigList*>(newValue);
+
+        bool unknownPresent{ false }, groundPresent{ false };
+        for (const auto& segmentationClass : segmentationClasses)
+        {
+            if (segmentationClass.m_classId == SegmentationClassConfiguration::UnknownClass.m_classId &&
+                segmentationClass.m_className == SegmentationClassConfiguration::UnknownClass.m_className)
+            {
+                unknownPresent = true;
+            }
+
+            if (segmentationClass.m_classId == SegmentationClassConfiguration::GroundClass.m_classId &&
+                segmentationClass.m_className == SegmentationClassConfiguration::GroundClass.m_className)
+            {
+                groundPresent = true;
+            }
+        }
+
+        if (!unknownPresent || !groundPresent)
+        {
+            return AZ::Failure(
+                AZStd::string::format("Segmentation class with name %s must exist.", (!unknownPresent ? "Unknown" : "Ground")));
+        }
+
+        return AZ::Success();
+    }
+
+    void ClassSegmentationConfigurationComponent::ConstructSegmentationClassMaps()
+    {
+        m_classIdToColor.reserve(m_segmentationClasses.size());
+        m_tagToClassId.reserve(m_segmentationClasses.size());
+
+        for (const auto& segmentationClass : m_segmentationClasses)
+        {
+            m_classIdToColor.insert(AZStd::make_pair(segmentationClass.m_classId, segmentationClass.m_classColor));
+            m_tagToClassId.insert(AZStd::make_pair(LmbrCentral::Tag(segmentationClass.m_className), segmentationClass.m_classId));
+        }
+    }
+
+    AZ::Crc32 ClassSegmentationConfigurationComponent::OnSegmentationClassesChanged()
+    {
+        bool unknownPresent{ false }, groundPresent{ false };
+        for (auto& segmentationClass : m_segmentationClasses)
+        {
+            if (segmentationClass.m_classId == SegmentationClassConfiguration::UnknownClass.m_classId)
+            {
+                unknownPresent = true;
+            }
+
+            if (segmentationClass.m_classId == SegmentationClassConfiguration::GroundClass.m_classId)
+            {
+                groundPresent = true;
+            }
+        }
+
+        if (unknownPresent && groundPresent)
+        {
+            return AZ::Edit::PropertyRefreshLevels::None;
+        }
+
+        if (!unknownPresent)
+        {
+            m_segmentationClasses.push_back(SegmentationClassConfiguration::UnknownClass);
+        }
+
+        if (!groundPresent)
+        {
+            m_segmentationClasses.push_back(SegmentationClassConfiguration::GroundClass);
+        }
+
+        return AZ::Edit::PropertyRefreshLevels::EntireTree;
+    }
+} // namespace ROS2

+ 50 - 0
Gems/ROS2/Code/Source/Lidar/ClassSegmentationConfigurationComponent.h

@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <AzCore/Component/Component.h>
+#include <AzCore/Math/Color.h>
+#include <ROS2/Lidar/ClassSegmentationBus.h>
+#include <ROS2/Lidar/SegmentationClassConfiguration.h>
+
+namespace ROS2
+{
+    class ClassSegmentationConfigurationComponent
+        : public AZ::Component
+        , ClassSegmentationRequestBus::Handler
+    {
+    public:
+        AZ_COMPONENT(ClassSegmentationConfigurationComponent, "{bab1ea0c-7456-40ea-bc1e-71697137c27c}", AZ::Component);
+
+        ClassSegmentationConfigurationComponent() = default;
+        ~ClassSegmentationConfigurationComponent() override = default;
+
+        static void Reflect(AZ::ReflectContext* context);
+
+        static void GetProvidedServices(AZ::ComponentDescriptor::DependencyArrayType& provided);
+        static void GetIncompatibleServices(AZ::ComponentDescriptor::DependencyArrayType& incompatible);
+
+        // ClassSegmentationRequestBus overrides
+        AZ::Color GetClassColor(uint8_t classId) const;
+        AZStd::optional<uint8_t> GetClassIdForTag(LmbrCentral::Tag tag) const;
+        const SegmentationClassConfigList& GetClassConfigList() const;
+
+        // AZ::Component overrides
+        void Activate() override;
+        void Deactivate() override;
+
+    private:
+        AZ::Outcome<void, AZStd::string> ValidateSegmentationClasses(void* newValue, const AZ::TypeId& valueType) const;
+        void ConstructSegmentationClassMaps();
+        AZ::Crc32 OnSegmentationClassesChanged();
+
+        SegmentationClassConfigList m_segmentationClasses;
+        AZStd::unordered_map<LmbrCentral::Tag, uint8_t> m_tagToClassId;
+        AZStd::unordered_map<uint8_t, AZ::Color> m_classIdToColor;
+    };
+} // namespace ROS2

+ 37 - 16
Gems/ROS2/Code/Source/Lidar/LidarCore.cpp

@@ -6,12 +6,13 @@
  *
  */
 
-#include "LidarCore.h"
 #include <Atom/RPI.Public/AuxGeom/AuxGeomFeatureProcessorInterface.h>
 #include <Atom/RPI.Public/Scene.h>
 #include <AzFramework/Physics/PhysicsSystem.h>
+#include <Lidar/LidarCore.h>
 #include <Lidar/LidarRegistrarSystemComponent.h>
 #include <ROS2/Frame/ROS2FrameComponent.h>
+#include <ROS2/Lidar/ClassSegmentationBus.h>
 #include <ROS2/ROS2Bus.h>
 #include <ROS2/Utilities/ROS2Names.h>
 
@@ -36,6 +37,34 @@ namespace ROS2
         }
     }
 
+    RaycastResultFlags LidarCore::GetRaycastResultFlagsForConfig(const LidarSensorConfiguration& configuration)
+    {
+        RaycastResultFlags flags = RaycastResultFlags::Range | RaycastResultFlags::Point;
+        if (configuration.m_lidarSystemFeatures & LidarSystemFeatures::Intensity)
+        {
+            flags |= RaycastResultFlags::Intensity;
+        }
+
+        if (configuration.m_lidarSystemFeatures & LidarSystemFeatures::Segmentation && configuration.m_isSegmentationEnabled)
+        {
+            if (ClassSegmentationInterface::Get())
+            {
+                flags |= RaycastResultFlags::SegmentationData;
+            }
+            else
+            {
+                AZ_Error(
+                    "ROS2",
+                    false,
+                    "Segmentation feature was enabled for this lidar sensor but the segmentation interface is not accessible. Make sure to "
+                    "either add the Class segmentation component to the level entity or disable the feature in the lidar component "
+                    "configuration.");
+            }
+        }
+
+        return flags;
+    }
+
     void LidarCore::ConnectToLidarRaycaster()
     {
         if (auto raycasterId = m_implementationToRaycasterMap.find(m_lidarConfiguration.m_lidarSystem);
@@ -72,10 +101,8 @@ namespace ROS2
                 m_lidarConfiguration.m_lidarParameters.m_noiseParameters.m_distanceNoiseStdDevRisePerMeter);
         }
 
-        LidarRaycasterRequestBus::Event(
-            m_lidarRaycasterId,
-            &LidarRaycasterRequestBus::Events::ConfigureRaycastResultFlags,
-            GetRaycastResultFlagsForConfig(m_lidarConfiguration));
+        m_resultFlags = GetRaycastResultFlagsForConfig(m_lidarConfiguration);
+        LidarRaycasterRequestBus::Event(m_lidarRaycasterId, &LidarRaycasterRequestBus::Events::ConfigureRaycastResultFlags, m_resultFlags);
 
         if (m_lidarConfiguration.m_lidarSystemFeatures & LidarSystemFeatures::CollisionLayers)
         {
@@ -106,17 +133,6 @@ namespace ROS2
         m_lastPoints.assign(pointsField.begin(), pointsField.end());
     }
 
-    RaycastResultFlags LidarCore::GetRaycastResultFlagsForConfig(const LidarSensorConfiguration& configuration)
-    {
-        RaycastResultFlags flags = RaycastResultFlags::Range | RaycastResultFlags::Point;
-        if (configuration.m_lidarSystemFeatures & LidarSystemFeatures::Intensity)
-        {
-            flags |= RaycastResultFlags::Intensity;
-        }
-
-        return flags;
-    }
-
     LidarCore::LidarCore(const AZStd::vector<LidarTemplate::LidarModel>& availableModels)
         : m_lidarConfiguration(availableModels)
     {
@@ -177,6 +193,11 @@ namespace ROS2
         return m_lidarRaycasterId;
     }
 
+    RaycastResultFlags LidarCore::GetResultFlags() const
+    {
+        return m_resultFlags;
+    }
+
     AZStd::optional<RaycastResults> LidarCore::PerformRaycast()
     {
         AZ::Entity* entity = nullptr;

+ 5 - 0
Gems/ROS2/Code/Source/Lidar/LidarCore.h

@@ -45,6 +45,10 @@ namespace ROS2
         //! @return Used raycaster's id.
         LidarId GetLidarRaycasterId() const;
 
+        //! Get the result flags used by this lidar.
+        //! @return Used result flags.
+        RaycastResultFlags GetResultFlags() const;
+
         //! Configuration according to which the lidar performs its raycasts.
         LidarSensorConfiguration m_lidarConfiguration;
 
@@ -66,5 +70,6 @@ namespace ROS2
         AZStd::vector<AZ::Vector3> m_lastPoints;
 
         AZ::EntityId m_entityId;
+        RaycastResultFlags m_resultFlags;
     };
 } // namespace ROS2

+ 65 - 10
Gems/ROS2/Code/Source/Lidar/LidarRaycaster.cpp

@@ -14,6 +14,7 @@
 #include <AzFramework/Physics/Shape.h>
 #include <Lidar/LidarRaycaster.h>
 #include <Lidar/LidarTemplateUtils.h>
+#include <ROS2/Lidar/SegmentationUtils.h>
 
 namespace ROS2
 {
@@ -112,6 +113,18 @@ namespace ROS2
         return requests;
     }
 
+    uint8_t LidarRaycaster::GetClassIdForEntity(AZ::EntityId entityId)
+    {
+        if (auto it = m_entityIdToClassIdCache.find(entityId); it != m_entityIdToClassIdCache.end())
+        {
+            return it->second;
+        }
+
+        const uint8_t classId = SegmentationUtils::FetchClassIdForEntity(entityId);
+        m_entityIdToClassIdCache.emplace(entityId, classId);
+        return classId;
+    }
+
     AZ::Outcome<RaycastResults, const char*> LidarRaycaster::PerformRaycast(const AZ::Transform& lidarTransform)
     {
         AZ_Assert(!m_rayRotations.empty(), "Ray poses are not configured. Unable to Perform a raycast.");
@@ -127,10 +140,12 @@ namespace ROS2
 
         const bool handlePoints = (m_resultFlags & RaycastResultFlags::Point) == RaycastResultFlags::Point;
         const bool handleRanges = (m_resultFlags & RaycastResultFlags::Range) == RaycastResultFlags::Range;
+        const bool handleSegmentation = (m_resultFlags & RaycastResultFlags::SegmentationData) == RaycastResultFlags::SegmentationData;
         RaycastResults results(m_resultFlags, rayDirections.size());
 
         AZStd::optional<RaycastResults::FieldSpan<RaycastResultFlags::Point>::iterator> pointIt;
         AZStd::optional<RaycastResults::FieldSpan<RaycastResultFlags::Range>::iterator> rangeIt;
+        AZStd::optional<RaycastResults::FieldSpan<RaycastResultFlags::SegmentationData>::iterator> segmentationIt;
         if (handlePoints)
         {
             pointIt = results.GetFieldSpan<RaycastResultFlags::Point>().value().begin();
@@ -139,6 +154,10 @@ namespace ROS2
         {
             rangeIt = results.GetFieldSpan<RaycastResultFlags::Range>().value().begin();
         }
+        if (handleSegmentation)
+        {
+            segmentationIt = results.GetFieldSpan<RaycastResultFlags::SegmentationData>().value().begin();
+        }
 
         auto* sceneInterface = AZ::Interface<AzPhysics::SceneInterface>::Get();
         auto requestResults = sceneInterface->QuerySceneBatch(m_sceneHandle, requests);
@@ -157,35 +176,65 @@ namespace ROS2
                 hitRange = -AZStd::numeric_limits<float>::infinity();
             }
 
-            bool assigned = false;
-            if (handleRanges)
+            bool wasUsed = false;
+            if (rangeIt.has_value())
             {
                 *rangeIt.value() = hitRange;
-                ++rangeIt.value();
-                assigned = true;
+                wasUsed = true;
             }
 
-            if (handlePoints)
+            if (pointIt.has_value())
             {
                 if (hitRange == maxRange)
                 {
                     // to properly visualize max points they need to be transformed to local coordinate system before applying maxRange
                     const AZ::Vector3 maxPoint = lidarTransform.TransformPoint(localTransform.TransformVector(rayDirections[i]) * hitRange);
                     *pointIt.value() = maxPoint;
-                    ++pointIt.value();
-                    assigned = true;
+                    wasUsed = true;
                 }
                 else if (!AZStd::isinf(hitRange))
                 {
                     // otherwise they are already calculated by PhysX
                     *pointIt.value() = requestResult.m_hits[0].m_position;
-                    ++pointIt.value();
-                    assigned = true;
+                    wasUsed = true;
                 }
             }
 
-            if (assigned)
+            if (segmentationIt.has_value())
             {
+                segmentationIt.value()->m_classId = 0;
+                segmentationIt.value()->m_entityId = 0;
+
+                if (requestResult)
+                {
+                    const auto entityId = requestResult.m_hits[0].m_entityId;
+                    const uint8_t classId = GetClassIdForEntity(entityId);
+                    const int32_t compressedEntityId = CompressEntityId(entityId);
+
+                    segmentationIt.value()->m_classId = classId;
+                    segmentationIt.value()->m_entityId = compressedEntityId;
+                }
+
+                wasUsed = true;
+            }
+
+            if (wasUsed)
+            {
+                if (rangeIt.has_value())
+                {
+                    ++rangeIt.value();
+                }
+
+                if (pointIt.has_value())
+                {
+                    ++pointIt.value();
+                }
+
+                if (segmentationIt.has_value())
+                {
+                    ++segmentationIt.value();
+                }
+
                 ++usedSize;
             }
         }
@@ -202,4 +251,10 @@ namespace ROS2
     {
         m_addMaxRangePoints = addMaxRangePoints;
     }
+
+    int32_t LidarRaycaster::CompressEntityId(AZ::EntityId entityId)
+    {
+        // Mapping the 64 bit entity ID onto a 32 integer may result in collisions but the chances are slim.
+        return (aznumeric_cast<AZ::u64>(entityId) >> 32) ^ (aznumeric_cast<AZ::u64>(entityId) & 0xFFFFFFFF);
+    }
 } // namespace ROS2

+ 5 - 0
Gems/ROS2/Code/Source/Lidar/LidarRaycaster.h

@@ -37,8 +37,12 @@ namespace ROS2
         void ConfigureMaxRangePointAddition(bool addMaxRangePoints) override;
 
     private:
+        static int32_t CompressEntityId(AZ::EntityId entityId);
+
         AzPhysics::SceneQueryRequests prepareRequests(
             const AZ::Transform& lidarTransform, const AZStd::vector<AZ::Vector3>& rayDirections) const;
+        [[nodiscard]] uint8_t GetClassIdForEntity(AZ::EntityId entityId);
+
         LidarId m_busId;
         //! EntityId that is used to acquire the physics scene handle.
         AZ::EntityId m_sceneEntityId;
@@ -50,5 +54,6 @@ namespace ROS2
         AZStd::vector<AZ::Quaternion> m_rayRotations{ { AZ::Quaternion::CreateZero() } };
 
         AZStd::unordered_set<AZ::u32> m_ignoredCollisionLayers;
+        AZStd::unordered_map<AZ::EntityId, uint8_t> m_entityIdToClassIdCache;
     };
 } // namespace ROS2

+ 14 - 1
Gems/ROS2/Code/Source/Lidar/LidarSensorConfiguration.cpp

@@ -17,12 +17,13 @@ namespace ROS2
         if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
         {
             serializeContext->Class<LidarSensorConfiguration>()
-                ->Version(1)
+                ->Version(2)
                 ->Field("lidarModelName", &LidarSensorConfiguration::m_lidarModelName)
                 ->Field("lidarImplementation", &LidarSensorConfiguration::m_lidarSystem)
                 ->Field("LidarParameters", &LidarSensorConfiguration::m_lidarParameters)
                 ->Field("IgnoredLayerIndices", &LidarSensorConfiguration::m_ignoredCollisionLayers)
                 ->Field("ExcludedEntities", &LidarSensorConfiguration::m_excludedEntities)
+                ->Field("IsSegmentationEnabled", &LidarSensorConfiguration::m_isSegmentationEnabled)
                 ->Field("PointsAtMax", &LidarSensorConfiguration::m_addPointsAtMax);
 
             if (AZ::EditContext* ec = serializeContext->GetEditContext())
@@ -58,6 +59,13 @@ namespace ROS2
                     ->Attribute(AZ::Edit::Attributes::AutoExpand, true)
                     ->Attribute(AZ::Edit::Attributes::ContainerCanBeModified, true)
                     ->Attribute(AZ::Edit::Attributes::Visibility, &LidarSensorConfiguration::IsEntityExclusionVisible)
+                    ->DataElement(
+                        AZ::Edit::UIHandlers::Default,
+                        &LidarSensorConfiguration::m_isSegmentationEnabled,
+                        "Enable Segmentation",
+                        "Enable point cloud segmentation. Note: Make sure to add the Class Segmentation Configuration Component to the "
+                        "level entity.")
+                    ->Attribute(AZ::Edit::Attributes::Visibility, &LidarSensorConfiguration::IsSegmentationConfigurationVisible)
                     ->DataElement(
                         AZ::Edit::UIHandlers::Default,
                         &LidarSensorConfiguration::m_addPointsAtMax,
@@ -140,6 +148,11 @@ namespace ROS2
         return m_lidarSystemFeatures & LidarSystemFeatures::MaxRangePoints;
     }
 
+    bool LidarSensorConfiguration::IsSegmentationConfigurationVisible() const
+    {
+        return m_lidarSystemFeatures & LidarSystemFeatures::Segmentation;
+    }
+
     AZ::Crc32 LidarSensorConfiguration::OnLidarModelSelected()
     {
         FetchLidarModelConfiguration();

+ 6 - 5
Gems/ROS2/Code/Source/Lidar/LidarSensorConfiguration.h

@@ -7,14 +7,14 @@
  */
 #pragma once
 
+#include "ROS2/Lidar/LidarRegistrarBus.h"
 #include <AzCore/Component/EntityId.h>
 #include <AzCore/RTTI/RTTI.h>
 #include <AzCore/Serialization/SerializeContext.h>
 #include <AzCore/std/string/string.h>
-
-#include "LidarRegistrarSystemComponent.h"
-#include "LidarTemplate.h"
-#include "LidarTemplateUtils.h"
+#include <Lidar/LidarRegistrarSystemComponent.h>
+#include <Lidar/LidarTemplate.h>
+#include <Lidar/LidarTemplateUtils.h>
 
 namespace ROS2
 {
@@ -39,6 +39,7 @@ namespace ROS2
         AZStd::unordered_set<AZ::u32> m_ignoredCollisionLayers;
         AZStd::vector<AZ::EntityId> m_excludedEntities;
 
+        bool m_isSegmentationEnabled = false;
         bool m_addPointsAtMax = false;
 
     private:
@@ -46,7 +47,7 @@ namespace ROS2
         bool IsIgnoredLayerConfigurationVisible() const;
         bool IsEntityExclusionVisible() const;
         bool IsMaxPointsConfigurationVisible() const;
-
+        bool IsSegmentationConfigurationVisible() const;
         //! Update the lidar configuration based on the current lidar model selected.
         void FetchLidarModelConfiguration();
 

+ 2 - 2
Gems/ROS2/Code/Source/Lidar/LidarSystem.cpp

@@ -25,8 +25,8 @@ namespace ROS2
     void LidarSystem::Activate()
     {
         static constexpr const char* Description = "Collider-based lidar implementation that uses the PhysX engine's raycasting.";
-        static constexpr auto SupportedFeatures =
-            aznumeric_cast<LidarSystemFeatures>(LidarSystemFeatures::CollisionLayers | LidarSystemFeatures::MaxRangePoints);
+        static constexpr auto SupportedFeatures = aznumeric_cast<LidarSystemFeatures>(
+            LidarSystemFeatures::CollisionLayers | LidarSystemFeatures::MaxRangePoints | LidarSystemFeatures::Segmentation);
 
         LidarSystemRequestBus::Handler::BusConnect(AZ_CRC(SystemName));
 

+ 1 - 0
Gems/ROS2/Code/Source/Lidar/PointCloudMessageBuilder.h

@@ -9,6 +9,7 @@
 
 #include <sensor_msgs/msg/point_cloud2.hpp>
 #include <AzCore/std/optional.h>
+#include <AzCore/std/string/string.h>
 
 namespace ROS2
 {

+ 85 - 10
Gems/ROS2/Code/Source/Lidar/ROS2LidarSensorComponent.cpp

@@ -12,6 +12,7 @@
 #include <Lidar/PointCloudMessageBuilder.h>
 #include <Lidar/ROS2LidarSensorComponent.h>
 #include <ROS2/Frame/ROS2FrameComponent.h>
+#include <ROS2/Lidar/ClassSegmentationBus.h>
 #include <ROS2/Utilities/ROS2Names.h>
 #include <sensor_msgs/point_cloud2_iterator.hpp>
 
@@ -103,6 +104,13 @@ namespace ROS2
             const TopicConfiguration& publisherConfig = m_sensorConfiguration.m_publishersConfigurations[PointCloudType];
             AZStd::string fullTopic = ROS2Names::GetNamespacedName(GetNamespace(), publisherConfig.m_topic);
             m_pointCloudPublisher = ros2Node->create_publisher<sensor_msgs::msg::PointCloud2>(fullTopic.data(), publisherConfig.GetQoS());
+
+            const auto resultFlags = m_lidarCore.GetResultFlags();
+            if (IsFlagEnabled(RaycastResultFlags::SegmentationData, resultFlags))
+            {
+                m_segmentationClassesPublisher = ros2Node->create_publisher<vision_msgs::msg::LabelInfo>(
+                    ROS2Names::GetNamespacedName(GetNamespace(), "segmentation_classes").data(), publisherConfig.GetQoS());
+            }
         }
 
         StartSensor(
@@ -150,10 +158,11 @@ namespace ROS2
         PublishRaycastResults(lastScanResults.value());
     }
 
+    template<typename T>
+    using Pc2MsgIt = sensor_msgs::PointCloud2Iterator<T>;
+
     void ROS2LidarSensorComponent::PublishRaycastResults(const RaycastResults& results)
     {
-        const bool isIntensityEnabled = m_lidarCore.m_lidarConfiguration.m_lidarSystemFeatures & LidarSystemFeatures::Intensity;
-
         auto builder = PointCloud2MessageBuilder(
             GetEntity()->FindComponent<ROS2FrameComponent>()->GetFrameID(), ROS2Interface::Get()->GetROSTimestamp(), results.GetCount());
 
@@ -161,30 +170,66 @@ namespace ROS2
             .AddField("y", sensor_msgs::msg::PointField::FLOAT32)
             .AddField("z", sensor_msgs::msg::PointField::FLOAT32);
 
-        if (isIntensityEnabled)
+        if (results.IsFieldPresent<RaycastResultFlags::Intensity>())
         {
             builder.AddField("intensity", sensor_msgs::msg::PointField::FLOAT32);
         }
-        sensor_msgs::msg::PointCloud2 message = builder.Get();
 
+        if (results.IsFieldPresent<RaycastResultFlags::SegmentationData>())
+        {
+            builder.AddField("entity_id", sensor_msgs::msg::PointField::INT32);
+            builder.AddField("class_id", sensor_msgs::msg::PointField::UINT8);
+            builder.AddField("rgba", sensor_msgs::msg::PointField::UINT32);
+        }
+
+        sensor_msgs::msg::PointCloud2 message = builder.Get();
 
-        sensor_msgs::PointCloud2Iterator<float> messageXIt(message, "x");
-        sensor_msgs::PointCloud2Iterator<float> messageYIt(message, "y");
-        sensor_msgs::PointCloud2Iterator<float> messageZIt(message, "z");
+        Pc2MsgIt<float> messageXIt(message, "x");
+        Pc2MsgIt<float> messageYIt(message, "y");
+        Pc2MsgIt<float> messageZIt(message, "z");
 
         const auto positionField = results.GetConstFieldSpan<RaycastResultFlags::Point>().value();
         auto positionIt = positionField.begin();
 
-        AZStd::optional<sensor_msgs::PointCloud2Iterator<float>> messageIntensityIt;
+        AZStd::optional<Pc2MsgIt<float>> messageIntensityIt;
         AZStd::optional<RaycastResults::FieldSpan<RaycastResultFlags::Intensity>::const_iterator> intensityIt;
-        if (isIntensityEnabled)
+        if (results.IsFieldPresent<RaycastResultFlags::Intensity>())
         {
-            messageIntensityIt = sensor_msgs::PointCloud2Iterator<float>(message, "intensity");
+            messageIntensityIt = Pc2MsgIt<float>(message, "intensity");
             intensityIt = results.GetConstFieldSpan<RaycastResultFlags::Intensity>().value().begin();
         }
 
+        struct MessageSegmentationIterators
+        {
+            Pc2MsgIt<int32_t> m_entityIdIt;
+            Pc2MsgIt<uint8_t> m_classIdIt;
+            Pc2MsgIt<uint32_t> m_rgbaIt;
+        };
+
+        AZStd::optional<MessageSegmentationIterators> messageSegDataIts;
+        AZStd::optional<RaycastResults::FieldSpan<RaycastResultFlags::SegmentationData>::const_iterator> segDataIt;
+        if (results.IsFieldPresent<RaycastResultFlags::SegmentationData>())
+        {
+            messageSegDataIts = MessageSegmentationIterators{
+                Pc2MsgIt<int32_t>(message, "entity_id"),
+                Pc2MsgIt<uint8_t>(message, "class_id"),
+                Pc2MsgIt<uint32_t>(message, "rgba"),
+            };
+
+            segDataIt = results.GetConstFieldSpan<RaycastResultFlags::SegmentationData>().value().begin();
+        }
+
         const auto entityTransform = GetEntity()->FindComponent<AzFramework::TransformComponent>();
         const auto inverseLidarTM = entityTransform->GetWorldTM().GetInverse();
+
+        auto* classSegmentationInterface = ClassSegmentationInterface::Get();
+        AZ_Warning(
+            __func__,
+            !results.IsFieldPresent<RaycastResultFlags::SegmentationData>() || classSegmentationInterface,
+            "Segmentation data was requested but the Class Segmentation interface was unavailable. Unable to fetch segmentation class "
+            "data. Please make sure to add the Class Segmentation Configuration Component to the Level Entity for this feature to work "
+            "properly.");
+
         for (size_t i = 0; i < results.GetCount(); ++i)
         {
             AZ::Vector3 point = *positionIt;
@@ -200,6 +245,22 @@ namespace ROS2
                 ++messageIntensityIt.value();
             }
 
+            if (messageSegDataIts.has_value() && segDataIt.has_value() && classSegmentationInterface)
+            {
+                const ResultTraits<RaycastResultFlags::SegmentationData>::Type segmentationData = *segDataIt.value();
+                *messageSegDataIts->m_entityIdIt = segmentationData.m_entityId;
+                *messageSegDataIts->m_classIdIt = segmentationData.m_classId;
+
+                AZ::Color color = classSegmentationInterface->GetClassColor(segmentationData.m_classId);
+                AZ::u32 rvizColorFormat = color.GetA8() << 24 | color.GetR8() << 16 | color.GetG8() << 8 | color.GetB8();
+                *messageSegDataIts->m_rgbaIt = rvizColorFormat;
+
+                ++segDataIt.value();
+                ++messageSegDataIts->m_entityIdIt;
+                ++messageSegDataIts->m_classIdIt;
+                ++messageSegDataIts->m_rgbaIt;
+            }
+
             ++positionIt;
             ++messageXIt;
             ++messageYIt;
@@ -207,5 +268,19 @@ namespace ROS2
         }
 
         m_pointCloudPublisher->publish(message);
+
+        if (m_segmentationClassesPublisher && classSegmentationInterface)
+        {
+            const auto& segmentationClassConfigList = classSegmentationInterface->GetClassConfigList();
+            vision_msgs::msg::LabelInfo segmentationClasses;
+            for (const auto& segmentationClass : segmentationClassConfigList)
+            {
+                vision_msgs::msg::VisionClass visionClass;
+                visionClass.class_id = segmentationClass.m_classId;
+                visionClass.class_name = segmentationClass.m_className.c_str();
+                segmentationClasses.class_map.push_back(visionClass);
+            }
+            m_segmentationClassesPublisher->publish(segmentationClasses);
+        }
     }
 } // namespace ROS2

+ 2 - 0
Gems/ROS2/Code/Source/Lidar/ROS2LidarSensorComponent.h

@@ -15,6 +15,7 @@
 #include <ROS2/Sensor/ROS2SensorComponentBase.h>
 #include <rclcpp/publisher.hpp>
 #include <sensor_msgs/msg/point_cloud2.hpp>
+#include <vision_msgs/msg/label_info.hpp>
 
 #include "LidarCore.h"
 #include "LidarRaycaster.h"
@@ -48,6 +49,7 @@ namespace ROS2
 
         bool m_canRaycasterPublish = false;
         std::shared_ptr<rclcpp::Publisher<sensor_msgs::msg::PointCloud2>> m_pointCloudPublisher;
+        std::shared_ptr<rclcpp::Publisher<vision_msgs::msg::LabelInfo>> m_segmentationClassesPublisher;
 
         LidarCore m_lidarCore;
 

+ 5 - 0
Gems/ROS2/Code/Source/Lidar/RaycastResults.cpp

@@ -15,6 +15,7 @@ namespace ROS2
         EnsureFlagSatisfied<RaycastResultFlags::Point>(flags, count);
         EnsureFlagSatisfied<RaycastResultFlags::Range>(flags, count);
         EnsureFlagSatisfied<RaycastResultFlags::Intensity>(flags, count);
+        EnsureFlagSatisfied<RaycastResultFlags::SegmentationData>(flags, count);
     }
 
     RaycastResults::RaycastResults(RaycastResults&& other)
@@ -22,6 +23,7 @@ namespace ROS2
         , m_points{ AZStd::move(other.m_points) }
         , m_ranges{ AZStd::move(other.m_ranges) }
         , m_intensities{ AZStd::move(other.m_intensities) }
+        , m_segmentationData{ AZStd::move(other.m_segmentationData) }
     {
         other.m_count = 0U;
     }
@@ -32,6 +34,7 @@ namespace ROS2
         ClearFieldIfPresent<RaycastResultFlags::Point>();
         ClearFieldIfPresent<RaycastResultFlags::Range>();
         ClearFieldIfPresent<RaycastResultFlags::Intensity>();
+        ClearFieldIfPresent<RaycastResultFlags::SegmentationData>();
     }
 
     void RaycastResults::Resize(size_t count)
@@ -40,6 +43,7 @@ namespace ROS2
         ResizeFieldIfPresent<RaycastResultFlags::Point>(count);
         ResizeFieldIfPresent<RaycastResultFlags::Range>(count);
         ResizeFieldIfPresent<RaycastResultFlags::Intensity>(count);
+        ResizeFieldIfPresent<RaycastResultFlags::SegmentationData>(count);
     }
 
     RaycastResults& RaycastResults::operator=(RaycastResults&& other)
@@ -55,6 +59,7 @@ namespace ROS2
         m_points = AZStd::move(other.m_points);
         m_ranges = AZStd::move(other.m_ranges);
         m_intensities = AZStd::move(other.m_intensities);
+        m_segmentationData = AZStd::move(other.m_segmentationData);
 
         return *this;
     }

+ 45 - 0
Gems/ROS2/Code/Source/Lidar/SegmentationClassConfiguration.cpp

@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#include <AzCore/Serialization/EditContext.h>
+#include <AzCore/Serialization/EditContextConstants.inl>
+#include <ROS2/Lidar/ClassSegmentationBus.h>
+#include <ROS2/Lidar/SegmentationClassConfiguration.h>
+
+namespace ROS2
+{
+    const SegmentationClassConfiguration SegmentationClassConfiguration::UnknownClass =
+        SegmentationClassConfiguration{ "Unknown", UnknownClassId, AZ::Colors::White };
+    const SegmentationClassConfiguration SegmentationClassConfiguration::GroundClass =
+        SegmentationClassConfiguration{ "Ground", TerrainClassId, AZ::Colors::Brown };
+
+    void SegmentationClassConfiguration::Reflect(AZ::ReflectContext* context)
+    {
+        if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
+        {
+            serializeContext->Class<SegmentationClassConfiguration>()
+                ->Version(0)
+                ->Field("className", &SegmentationClassConfiguration::m_className)
+                ->Field("classId", &SegmentationClassConfiguration::m_classId)
+                ->Field("classColor", &SegmentationClassConfiguration::m_classColor);
+
+            if (AZ::EditContext* ec = serializeContext->GetEditContext())
+            {
+                ec->Class<SegmentationClassConfiguration>(
+                      "Lidar Segmentation Class Configuration", "Lidar Segmentation Class configuration")
+                    ->DataElement(
+                        AZ::Edit::UIHandlers::Default, &SegmentationClassConfiguration::m_className, "Class Name", "Name of the class")
+                    ->Attribute(AZ::Edit::Attributes::ContainerCanBeModified, true)
+                    ->DataElement(AZ::Edit::UIHandlers::Default, &SegmentationClassConfiguration::m_classId, "Class Id", "Id of the class")
+                    ->Attribute(AZ::Edit::Attributes::ContainerCanBeModified, true)
+                    ->DataElement(
+                        AZ::Edit::UIHandlers::Default, &SegmentationClassConfiguration::m_classColor, "Class Color", "Color of the class")
+                    ->Attribute(AZ::Edit::Attributes::ContainerCanBeModified, true);
+            }
+        }
+    }
+} // namespace ROS2

+ 58 - 0
Gems/ROS2/Code/Source/Lidar/SegmentationUtils.cpp

@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#include <ROS2/Lidar/ClassSegmentationBus.h>
+#include <ROS2/Lidar/SegmentationUtils.h>
+
+namespace ROS2::SegmentationUtils
+{
+    uint8_t FetchClassIdForEntity(AZ::EntityId entityId)
+    {
+        AZStd::optional<uint8_t> classId;
+
+        LmbrCentral::Tags entityTags;
+        LmbrCentral::TagComponentRequestBus::EventResult(entityTags, entityId, &LmbrCentral::TagComponentRequests::GetTags);
+        auto* segmentationInterface = ClassSegmentationInterface::Get();
+        if (!segmentationInterface)
+        {
+            return UnknownClassId;
+        }
+
+        for (const auto& tag : entityTags)
+        {
+            AZStd::optional<uint8_t> tagClassId = segmentationInterface->GetClassIdForTag(tag);
+            if (tagClassId.has_value())
+            {
+                if (classId.has_value())
+                {
+                    AZ_Warning(
+                        "EntityManager",
+                        false,
+                        "Entity with ID: %s has more than one class tag. Assigning first class ID %u",
+                        entityId.ToString().c_str(),
+                        classId.value());
+                }
+                else
+                {
+                    classId = tagClassId.value();
+                }
+            }
+        }
+
+        if (!classId.has_value())
+        {
+            AZ_Warning(
+                "EntityManager",
+                false,
+                "Entity with ID: %s has no class tag. Assigning unknown class ID: %u",
+                entityId.ToString().c_str(),
+                UnknownClassId);
+        }
+
+        return classId.value_or(UnknownClassId);
+    }
+} // namespace ROS2::SegmentationUtils

+ 2 - 0
Gems/ROS2/Code/Source/ROS2ModuleInterface.h

@@ -19,6 +19,7 @@
 #include <Gripper/GripperActionServerComponent.h>
 #include <Gripper/VacuumGripperComponent.h>
 #include <Imu/ROS2ImuSensorComponent.h>
+#include <Lidar/ClassSegmentationConfigurationComponent.h>
 #include <Lidar/LidarRegistrarSystemComponent.h>
 #include <Lidar/ROS2Lidar2DSensorComponent.h>
 #include <Lidar/ROS2LidarSensorComponent.h>
@@ -99,6 +100,7 @@ namespace ROS2
                     ROS2ContactSensorComponent::CreateDescriptor(),
                     FollowingCameraComponent::CreateDescriptor(),
                     GeoReferenceLevelComponent::CreateDescriptor(),
+                    ClassSegmentationConfigurationComponent::CreateDescriptor(),
                 });
         }
 

+ 4 - 0
Gems/ROS2/Code/ros2_files.cmake

@@ -79,6 +79,10 @@ set(FILES
         Source/Lidar/ROS2Lidar2DSensorComponent.h
         Source/Lidar/ROS2LidarSensorComponent.cpp
         Source/Lidar/ROS2LidarSensorComponent.h
+        Source/Lidar/ClassSegmentationConfigurationComponent.cpp
+        Source/Lidar/ClassSegmentationConfigurationComponent.h
+        Source/Lidar/SegmentationClassConfiguration.cpp
+        Source/Lidar/SegmentationUtils.cpp
         Source/Manipulation/Controllers/JointsArticulationControllerComponent.cpp
         Source/Manipulation/Controllers/JointsArticulationControllerComponent.h
         Source/Manipulation/Controllers/JointsPIDControllerComponent.cpp

+ 4 - 1
Gems/ROS2/Code/ros2_header_files.cmake

@@ -37,10 +37,13 @@ set(FILES
         Include/ROS2/RobotImporter/SDFormatModelPluginImporterHook.h
         Include/ROS2/RobotImporter/SDFormatSensorImporterHook.h
         Include/ROS2/ROS2SensorTypesIds.h
+        Include/ROS2/Lidar/ClassSegmentationBus.h
         Include/ROS2/Lidar/LidarRaycasterBus.h
         Include/ROS2/Lidar/LidarSystemBus.h
         Include/ROS2/Lidar/LidarRegistrarBus.h
         Include/ROS2/Lidar/RaycastResults.h
+        Include/ROS2/Lidar/SegmentationClassConfiguration.h
+        Include/ROS2/Lidar/SegmentationUtils.h
         Include/ROS2/ROS2Bus.h
         Include/ROS2/ROS2GemUtilities.h
         Include/ROS2/Sensor/Events/EventSourceAdapter.h
@@ -56,4 +59,4 @@ set(FILES
         Include/ROS2/Utilities/ROS2Conversions.h
         Include/ROS2/Utilities/ROS2Names.h
         Include/ROS2/VehicleDynamics/VehicleInputControlBus.h
-        )
+)