Source code for sleap.gui.training_editor
"""
Module for viewing and modifying training profiles.
"""
import attr
import cattr
from typing import Optional
from PySide2 import QtWidgets
from sleap.gui.formbuilder import YamlFormWidget
from sleap.gui.filedialog import FileDialog
from sleap.util import get_package_file
[docs]class TrainingEditor(QtWidgets.QDialog):
"""
Dialog for viewing and modifying training profiles.
Args:
profile_filename: Path to saved training profile to view.
saved_files: When user saved profile, it's path is added to this
list (which will be updated in code that created TrainingEditor).
"""
def __init__(
self,
profile_filename: Optional[str] = None,
saved_files: list = [],
skeleton: Optional["Skeleton"] = None,
*args,
**kwargs
):
super(TrainingEditor, self).__init__()
form_yaml = get_package_file("sleap/config/training_editor.yaml")
self.form_widgets = dict()
self.form_widgets["model"] = YamlFormWidget(
form_yaml, "model", title="Network Architecture"
)
self.form_widgets["datagen"] = YamlFormWidget(
form_yaml, "datagen", title="Data Generation/Preprocessing"
)
self.form_widgets["trainer"] = YamlFormWidget(
form_yaml, "trainer", title="Trainer"
)
self.form_widgets["output"] = YamlFormWidget(form_yaml, "output")
self.form_widgets["buttons"] = YamlFormWidget(form_yaml, "buttons")
self.form_widgets["spacer"] = YamlFormWidget(form_yaml, "spacer")
self.form_widgets["buttons"].mainAction.connect(self._save_as)
if hasattr(skeleton, "node_names"):
print(skeleton.node_names)
print(self.form_widgets["datagen"].form_layout.field_options_lists)
self.form_widgets["datagen"].set_field_options(
"instance_crop_ctr_node_ind", skeleton.node_names,
)
col1_layout = QtWidgets.QVBoxLayout()
col2_layout = QtWidgets.QVBoxLayout()
col1_layout.addWidget(self.form_widgets["model"])
col1_layout.addWidget(self.form_widgets["datagen"])
col1_layout.addWidget(self.form_widgets["spacer"])
col2_layout.addWidget(self.form_widgets["trainer"])
col2_layout.addWidget(self.form_widgets["output"])
col2_layout.addWidget(self.form_widgets["buttons"])
col_layout = QtWidgets.QHBoxLayout()
col_layout.addWidget(self._layout_widget(col1_layout))
col_layout.addWidget(self._layout_widget(col2_layout))
self.setLayout(col_layout)
self.profile_filename = profile_filename
self.saved_files = saved_files
@property
def profile_filename(self):
"""Returns path to currently loaded training profile."""
return self._profile_filename
@profile_filename.setter
def profile_filename(self, val):
"""Sets path to (and loads) training profile."""
self._profile_filename = val
# set window title
self.setWindowTitle(self.profile_filename)
# load file
if self.profile_filename:
self._load_profile(self.profile_filename)
@staticmethod
def _layout_widget(layout):
widget = QtWidgets.QWidget()
widget.setLayout(layout)
return widget
def _load_profile(self, profile_filename: str):
"""Loads training profile settings from file."""
from sleap.nn.model import ModelOutputType
from sleap.nn.job import TrainingJob
self.training_job = TrainingJob.load_json(profile_filename)
job_dict = cattr.unstructure(self.training_job)
job_dict["model"]["arch"] = job_dict["model"]["backbone_name"]
job_dict["model"]["output_type"] = str(self.training_job.model.output_type)
self.form_widgets["model"].set_form_data(job_dict["model"])
self.form_widgets["model"].set_form_data(job_dict["model"]["backbone"])
for name in "datagen,trainer,output".split(","):
self.form_widgets[name].set_form_data(job_dict["trainer"])
def _save_as(self):
"""Shows dialog to save training profile."""
# Show "Save" dialog
save_filename, _ = FileDialog.save(
self, caption="Save As...", dir=None, filter="Profile JSON (*.json)"
)
if len(save_filename):
from sleap.nn.model import Model, ModelOutputType
# from sleap.nn.training import TrainerConfig
from sleap.nn.job import TrainingJob, TrainerConfig
from sleap.nn import architectures
# Construct Model
model_data = self.form_widgets["model"].get_form_data()
arches = {arch.__name__: arch for arch in architectures.available_archs}
arch = arches[model_data["arch"]]
output_type = dict(
confmaps=ModelOutputType.CONFIDENCE_MAP,
pafs=ModelOutputType.PART_AFFINITY_FIELD,
centroids=ModelOutputType.CENTROIDS,
topdown_confidence_maps=ModelOutputType.TOPDOWN_CONFIDENCE_MAP,
)[model_data["output_type"]]
backbone_kwargs = {
key: val
for key, val in model_data.items()
if key in attr.fields_dict(arch).keys()
}
model = Model(output_type=output_type, backbone=arch(**backbone_kwargs))
# Construct Trainer
trainer_data = {
**self.form_widgets["datagen"].get_form_data(),
**self.form_widgets["output"].get_form_data(),
**self.form_widgets["trainer"].get_form_data(),
}
trainer_kwargs = {
key: val
for key, val in trainer_data.items()
if key in attr.fields_dict(TrainerConfig).keys()
}
trainer = TrainerConfig(**trainer_kwargs)
# Construct TrainingJob
training_job_kwargs = {
key: val
for key, val in trainer_data.items()
if key in attr.fields_dict(TrainingJob).keys()
}
training_job = TrainingJob(model, trainer, **training_job_kwargs)
# Write the file
TrainingJob.save_json(training_job, save_filename)
self.saved_files.append(save_filename)
self.profile_filename = save_filename
self.close()
if __name__ == "__main__":
import sys
profile_filename = None if len(sys.argv) <= 1 else sys.argv[1]
# profile_filename = "training_profiles/default_confmaps.json"
app = QtWidgets.QApplication([])
win = TrainingEditor(profile_filename)
win.show()
app.exec_()