소스 검색

Some last tuning and debugging

Signed-off-by: kberg-amzn <[email protected]>
kberg-amzn 2 년 전
부모
커밋
85d62bd8d6

BIN
Gems/MachineLearning/Assets/Models/NumberClassifier


+ 11 - 1
Gems/MachineLearning/Code/Source/Assets/TrainingDataView.cpp

@@ -69,6 +69,10 @@ namespace MachineLearning
     {
         AZ_Assert(m_sourceData, "No datasource assigned to view");
         AZ_Assert(index + m_first < m_last, "Out of range index requested");
+        if (m_firstCache != m_first || m_lastCache != m_last)
+        {
+            FillIndicies();
+        }
         return m_sourceData->GetLabelByIndex(m_indices[index]);
     }
 
@@ -76,13 +80,19 @@ namespace MachineLearning
     {
         AZ_Assert(m_sourceData, "No datasource assigned to view");
         AZ_Assert(index + m_first < m_last, "Out of range index requested");
+        if (m_firstCache != m_first || m_lastCache != m_last)
+        {
+            FillIndicies();
+        }
         return m_sourceData->GetDataByIndex(m_indices[index]);
     }
 
     void TrainingDataView::FillIndicies()
     {
         // Generate a set of training indices that we can later optionally shuffle
-        m_indices.resize(GetOriginalSize());
+        m_indices.resize(m_last);
         std::iota(m_indices.begin(), m_indices.end(), m_first);
+        m_firstCache = m_first;
+        m_lastCache = m_last;
     }
 }

+ 2 - 0
Gems/MachineLearning/Code/Source/Assets/TrainingDataView.h

@@ -51,6 +51,8 @@ namespace MachineLearning
 
         void FillIndicies();
 
+        AZStd::size_t m_firstCache = 0;
+        AZStd::size_t m_lastCache = 0;
         AZStd::vector<AZStd::size_t> m_indices;
         ILabeledTrainingDataPtr m_sourceData;
     };

+ 63 - 38
Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.cpp

@@ -146,6 +146,21 @@ namespace MachineLearning
         //trainingInstance->m_layerBiases[layerIndex].Draw(ImGui::GetColumnWidth(), 200.0f);
     }
 
+    void DrawDataPanel(TrainingDataView& data, AZStd::string& dataName, AZStd::string& labelName)
+    {
+        ImGui::PushID(&data);
+        int32_t firstElement = static_cast<int32_t>(data.m_first);
+        int32_t span = static_cast<int32_t>(data.m_last) - firstElement;
+        TextInputHelper("Asset file", dataName);
+        TextInputHelper("Label file", labelName);
+        ImGui::SliderInt("First", &firstElement, 0, static_cast<int32_t>(data.GetOriginalSize()));
+        ImGui::SameLine();
+        ImGui::SliderInt("Count", &span, 0, static_cast<int32_t>(data.GetOriginalSize() - firstElement));
+        data.m_first = firstElement;
+        data.m_last = firstElement + span;
+        ImGui::PopID();
+    }
+
     void MachineLearningDebugTrainingWindow::OnImGuiUpdate()
     {
         const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
@@ -265,28 +280,52 @@ namespace MachineLearning
                 }
             }
  
-            ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
             ImGui::NewLine();
+            ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
             ImGui::Text("Asset location: %s", m_selectedModel->GetAssetFile(AssetTypes::Model).c_str());
-            ImGui::NewLine();
-
-            ImGui::Text("Total samples: %d", trainingInstance->m_totalSamples);
-            ImGui::Text("Correct predictions: %d", trainingInstance->m_correctPredictions);
-            ImGui::Text("Incorrect predictions: %d", trainingInstance->m_incorrectPredictions);
 
-            const float accuracy = (static_cast<float>(trainingInstance->m_correctPredictions) * 100.0f) / static_cast<float>(trainingInstance->m_totalSamples);
-            ImGui::Text("Accuracy: %f", accuracy);
+            if (ImGui::BeginTable("Accuracy", 2, flags))
+            {
+                ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
+                ImGui::TableSetupColumn("Value", ImGuiTableColumnFlags_WidthStretch);
+                ImGui::TableHeadersRow();
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                ImGui::Text("Total samples");
+                ImGui::TableNextColumn();
+                ImGui::Text("%d", trainingInstance->m_totalSamples);
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                ImGui::Text("Correct predictions");
+                ImGui::TableNextColumn();
+                ImGui::Text("%d", trainingInstance->m_correctPredictions);
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                ImGui::Text("Incorrect predictions");
+                ImGui::TableNextColumn();
+                ImGui::Text("%d", trainingInstance->m_incorrectPredictions);
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                ImGui::Text("Accuracy");
+                ImGui::TableNextColumn();
+                const float accuracy = (static_cast<float>(trainingInstance->m_correctPredictions) * 100.0f) / static_cast<float>(trainingInstance->m_totalSamples);
+                ImGui::Text("%f", accuracy);
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                ImGui::Text("Test score");
+                ImGui::TableNextColumn();
+                ImGui::Text("%f", trainingInstance->m_testHistogram.GetLastValue());
+                ImGui::TableNextRow();
+                ImGui::TableNextColumn();
+                ImGui::Text("Train score");
+                ImGui::TableNextColumn();
+                ImGui::Text("%f", trainingInstance->m_trainHistogram.GetLastValue());
+                ImGui::EndTable();
+                ImGui::NewLine();
+            }
 
-            ImGui::Text("Test score: %f", trainingInstance->m_testHistogram.GetLastValue());
             trainingInstance->m_testHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
-            ImGui::Text("Train score: %f", trainingInstance->m_trainHistogram.GetLastValue());
             trainingInstance->m_trainHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
-            ImGui::Checkbox("Shuffle data", &trainingInstance->m_trainingCycle.m_shuffleTrainingData);
-            ImGui::NewLine();
-
-            ImGui::SliderFloat("LearningRate", &trainingInstance->m_trainingCycle.m_learningRate, 0.0f, 0.1f);
-            ImGui::SliderFloat("LearningRateDecay", &trainingInstance->m_trainingCycle.m_learningRateDecay, 0.0f, 1.0f);
-            ImGui::SliderFloat("EarlyStop", &trainingInstance->m_trainingCycle.m_earlyStopCost, 0.0f, 1.0f);
             ImGui::NewLine();
 
             ImGui::SliderInt("Batch size", &batchSize, 1, 1000);
@@ -299,37 +338,23 @@ namespace MachineLearning
             trainingInstance->m_trainingCycle.m_costFunction = static_cast<LossFunctions>(costMetric);
             ImGui::NewLine();
 
+            ImGui::SliderFloat("LearningRate", &trainingInstance->m_trainingCycle.m_learningRate, 0.0f, 0.1f);
+            ImGui::SliderFloat("LearningRateDecay", &trainingInstance->m_trainingCycle.m_learningRateDecay, 0.0f, 1.0f);
+            ImGui::SliderFloat("EarlyStop", &trainingInstance->m_trainingCycle.m_earlyStopCost, 0.0f, 1.0f);
+            ImGui::NewLine();
+
             ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.4f);
 
+            ImGui::Checkbox("Shuffle data", &trainingInstance->m_trainingCycle.m_shuffleTrainingData);
             if (ImGui::CollapsingHeader("Test data", ImGuiTreeNodeFlags_Framed))
             {
-                ImGui::PushID(&trainingInstance->m_trainingCycle.m_testData);
-                int32_t firstElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_testData.m_first);
-                int32_t lastElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_testData.m_last);
-                TextInputHelper("Asset file", trainingInstance->m_testDataName);
-                TextInputHelper("Label file", trainingInstance->m_testLabelName);
-                ImGui::SliderInt("First", &firstElement, 0, lastElement);
-                ImGui::SameLine();
-                ImGui::SliderInt("Last", &lastElement, firstElement, static_cast<int32_t>(trainingInstance->m_trainingCycle.m_testData.GetOriginalSize()));
-                trainingInstance->m_trainingCycle.m_testData.m_first = firstElement;
-                trainingInstance->m_trainingCycle.m_testData.m_last = lastElement;
-                ImGui::PopID();
+                DrawDataPanel(trainingInstance->m_trainingCycle.m_testData, trainingInstance->m_testDataName, trainingInstance->m_testLabelName);
             }
             ImGui::NewLine();
 
             if (ImGui::CollapsingHeader("Train data", ImGuiTreeNodeFlags_Framed))
             {
-                ImGui::PushID(&trainingInstance->m_trainingCycle.m_trainData);
-                int32_t firstElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_trainData.m_first);
-                int32_t lastElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_trainData.m_last);
-                TextInputHelper("Asset file", trainingInstance->m_trainDataName);
-                TextInputHelper("Label file", trainingInstance->m_trainLabelName);
-                ImGui::SliderInt("First", &firstElement, 0, lastElement);
-                ImGui::SameLine();
-                ImGui::SliderInt("Last", &lastElement, firstElement, static_cast<int32_t>(trainingInstance->m_trainingCycle.m_trainData.GetOriginalSize()));
-                trainingInstance->m_trainingCycle.m_trainData.m_first = firstElement;
-                trainingInstance->m_trainingCycle.m_trainData.m_last = lastElement;
-                ImGui::PopID();
+                DrawDataPanel(trainingInstance->m_trainingCycle.m_trainData, trainingInstance->m_trainDataName, trainingInstance->m_trainLabelName);
             }
 
             //for (AZStd::size_t layerIter = 0; layerIter < m_selectedModel->GetLayerCount(); ++layerIter)