|
@@ -146,6 +146,21 @@ namespace MachineLearning
|
|
|
//trainingInstance->m_layerBiases[layerIndex].Draw(ImGui::GetColumnWidth(), 200.0f);
|
|
|
}
|
|
|
|
|
|
+ void DrawDataPanel(TrainingDataView& data, AZStd::string& dataName, AZStd::string& labelName)
|
|
|
+ {
|
|
|
+ ImGui::PushID(&data);
|
|
|
+ int32_t firstElement = static_cast<int32_t>(data.m_first);
|
|
|
+ int32_t span = static_cast<int32_t>(data.m_last) - firstElement;
|
|
|
+ TextInputHelper("Asset file", dataName);
|
|
|
+ TextInputHelper("Label file", labelName);
|
|
|
+ ImGui::SliderInt("First", &firstElement, 0, static_cast<int32_t>(data.GetOriginalSize()));
|
|
|
+ ImGui::SameLine();
|
|
|
+ ImGui::SliderInt("Count", &span, 0, static_cast<int32_t>(data.GetOriginalSize() - firstElement));
|
|
|
+ data.m_first = firstElement;
|
|
|
+ data.m_last = firstElement + span;
|
|
|
+ ImGui::PopID();
|
|
|
+ }
|
|
|
+
|
|
|
void MachineLearningDebugTrainingWindow::OnImGuiUpdate()
|
|
|
{
|
|
|
const float TEXT_BASE_WIDTH = ImGui::CalcTextSize("A").x;
|
|
@@ -265,28 +280,52 @@ namespace MachineLearning
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
|
|
|
ImGui::NewLine();
|
|
|
+ ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
|
|
|
ImGui::Text("Asset location: %s", m_selectedModel->GetAssetFile(AssetTypes::Model).c_str());
|
|
|
- ImGui::NewLine();
|
|
|
-
|
|
|
- ImGui::Text("Total samples: %d", trainingInstance->m_totalSamples);
|
|
|
- ImGui::Text("Correct predictions: %d", trainingInstance->m_correctPredictions);
|
|
|
- ImGui::Text("Incorrect predictions: %d", trainingInstance->m_incorrectPredictions);
|
|
|
|
|
|
- const float accuracy = (static_cast<float>(trainingInstance->m_correctPredictions) * 100.0f) / static_cast<float>(trainingInstance->m_totalSamples);
|
|
|
- ImGui::Text("Accuracy: %f", accuracy);
|
|
|
+ if (ImGui::BeginTable("Accuracy", 2, flags))
|
|
|
+ {
|
|
|
+ ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
|
|
|
+ ImGui::TableSetupColumn("Value", ImGuiTableColumnFlags_WidthStretch);
|
|
|
+ ImGui::TableHeadersRow();
|
|
|
+ ImGui::TableNextRow();
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("Total samples");
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("%d", trainingInstance->m_totalSamples);
|
|
|
+ ImGui::TableNextRow();
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("Correct predictions");
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("%d", trainingInstance->m_correctPredictions);
|
|
|
+ ImGui::TableNextRow();
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("Incorrect predictions");
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("%d", trainingInstance->m_incorrectPredictions);
|
|
|
+ ImGui::TableNextRow();
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("Accuracy");
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ const float accuracy = (static_cast<float>(trainingInstance->m_correctPredictions) * 100.0f) / static_cast<float>(trainingInstance->m_totalSamples);
|
|
|
+ ImGui::Text("%f", accuracy);
|
|
|
+ ImGui::TableNextRow();
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("Test score");
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("%f", trainingInstance->m_testHistogram.GetLastValue());
|
|
|
+ ImGui::TableNextRow();
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("Train score");
|
|
|
+ ImGui::TableNextColumn();
|
|
|
+ ImGui::Text("%f", trainingInstance->m_trainHistogram.GetLastValue());
|
|
|
+ ImGui::EndTable();
|
|
|
+ ImGui::NewLine();
|
|
|
+ }
|
|
|
|
|
|
- ImGui::Text("Test score: %f", trainingInstance->m_testHistogram.GetLastValue());
|
|
|
trainingInstance->m_testHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
|
|
|
- ImGui::Text("Train score: %f", trainingInstance->m_trainHistogram.GetLastValue());
|
|
|
trainingInstance->m_trainHistogram.Draw(ImGui::GetColumnWidth(), 200.0f);
|
|
|
- ImGui::Checkbox("Shuffle data", &trainingInstance->m_trainingCycle.m_shuffleTrainingData);
|
|
|
- ImGui::NewLine();
|
|
|
-
|
|
|
- ImGui::SliderFloat("LearningRate", &trainingInstance->m_trainingCycle.m_learningRate, 0.0f, 0.1f);
|
|
|
- ImGui::SliderFloat("LearningRateDecay", &trainingInstance->m_trainingCycle.m_learningRateDecay, 0.0f, 1.0f);
|
|
|
- ImGui::SliderFloat("EarlyStop", &trainingInstance->m_trainingCycle.m_earlyStopCost, 0.0f, 1.0f);
|
|
|
ImGui::NewLine();
|
|
|
|
|
|
ImGui::SliderInt("Batch size", &batchSize, 1, 1000);
|
|
@@ -299,37 +338,23 @@ namespace MachineLearning
|
|
|
trainingInstance->m_trainingCycle.m_costFunction = static_cast<LossFunctions>(costMetric);
|
|
|
ImGui::NewLine();
|
|
|
|
|
|
+ ImGui::SliderFloat("LearningRate", &trainingInstance->m_trainingCycle.m_learningRate, 0.0f, 0.1f);
|
|
|
+ ImGui::SliderFloat("LearningRateDecay", &trainingInstance->m_trainingCycle.m_learningRateDecay, 0.0f, 1.0f);
|
|
|
+ ImGui::SliderFloat("EarlyStop", &trainingInstance->m_trainingCycle.m_earlyStopCost, 0.0f, 1.0f);
|
|
|
+ ImGui::NewLine();
|
|
|
+
|
|
|
ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.4f);
|
|
|
|
|
|
+ ImGui::Checkbox("Shuffle data", &trainingInstance->m_trainingCycle.m_shuffleTrainingData);
|
|
|
if (ImGui::CollapsingHeader("Test data", ImGuiTreeNodeFlags_Framed))
|
|
|
{
|
|
|
- ImGui::PushID(&trainingInstance->m_trainingCycle.m_testData);
|
|
|
- int32_t firstElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_testData.m_first);
|
|
|
- int32_t lastElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_testData.m_last);
|
|
|
- TextInputHelper("Asset file", trainingInstance->m_testDataName);
|
|
|
- TextInputHelper("Label file", trainingInstance->m_testLabelName);
|
|
|
- ImGui::SliderInt("First", &firstElement, 0, lastElement);
|
|
|
- ImGui::SameLine();
|
|
|
- ImGui::SliderInt("Last", &lastElement, firstElement, static_cast<int32_t>(trainingInstance->m_trainingCycle.m_testData.GetOriginalSize()));
|
|
|
- trainingInstance->m_trainingCycle.m_testData.m_first = firstElement;
|
|
|
- trainingInstance->m_trainingCycle.m_testData.m_last = lastElement;
|
|
|
- ImGui::PopID();
|
|
|
+ DrawDataPanel(trainingInstance->m_trainingCycle.m_testData, trainingInstance->m_testDataName, trainingInstance->m_testLabelName);
|
|
|
}
|
|
|
ImGui::NewLine();
|
|
|
|
|
|
if (ImGui::CollapsingHeader("Train data", ImGuiTreeNodeFlags_Framed))
|
|
|
{
|
|
|
- ImGui::PushID(&trainingInstance->m_trainingCycle.m_trainData);
|
|
|
- int32_t firstElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_trainData.m_first);
|
|
|
- int32_t lastElement = static_cast<int32_t>(trainingInstance->m_trainingCycle.m_trainData.m_last);
|
|
|
- TextInputHelper("Asset file", trainingInstance->m_trainDataName);
|
|
|
- TextInputHelper("Label file", trainingInstance->m_trainLabelName);
|
|
|
- ImGui::SliderInt("First", &firstElement, 0, lastElement);
|
|
|
- ImGui::SameLine();
|
|
|
- ImGui::SliderInt("Last", &lastElement, firstElement, static_cast<int32_t>(trainingInstance->m_trainingCycle.m_trainData.GetOriginalSize()));
|
|
|
- trainingInstance->m_trainingCycle.m_trainData.m_first = firstElement;
|
|
|
- trainingInstance->m_trainingCycle.m_trainData.m_last = lastElement;
|
|
|
- ImGui::PopID();
|
|
|
+ DrawDataPanel(trainingInstance->m_trainingCycle.m_trainData, trainingInstance->m_trainDataName, trainingInstance->m_trainLabelName);
|
|
|
}
|
|
|
|
|
|
//for (AZStd::size_t layerIter = 0; layerIter < m_selectedModel->GetLayerCount(); ++layerIter)
|