"""
Module for running training and inference from the main gui application.
"""
import os
import attr
import cattr
import numpy as np
from functools import reduce
from typing import Callable, Dict, List, Optional, Tuple
from sleap.io.dataset import Labels
from sleap.io.video import Video
from sleap.gui.filedialog import FileDialog
from sleap.gui.training_editor import TrainingEditor
from sleap.gui.formbuilder import YamlFormWidget
from sleap.nn.model import ModelOutputType
from sleap.nn.job import TrainingJob
from sleap import util
from PySide2 import QtWidgets, QtCore
# Sentinel menu entry that lets the user browse the filesystem for a
# training profile/model file instead of picking a discovered one.
SELECT_FILE_OPTION = "Select training run/model file..."

# Maps each job-selection menu name ("confmap"/"paf"/"centroid") to the
# ModelOutputType values whose profiles may appear in that menu.
MENU_NAME_TYPE_MAP = dict(
    confmap=(ModelOutputType.CONFIDENCE_MAP, ModelOutputType.TOPDOWN_CONFIDENCE_MAP,),
    paf=(ModelOutputType.PART_AFFINITY_FIELD,),
    centroid=(ModelOutputType.CENTROIDS,),
)
class InferenceDialog(QtWidgets.QDialog):
    """Training/inference dialog.

    The dialog can be used in different modes:
    * simplified training + inference (fewer controls)
    * expert training + inference (full controls)
    * inference only

    Arguments:
        labels_filename: Path to the dataset where we'll get training data.
        labels: The dataset where we'll get training data and add predictions.
        mode: String which specifies mode ("learning", "expert", or "inference").
    """

    # Emitted when the pipeline finishes; payload is the number of new
    # predicted frames (negative on failure/abort).
    learningFinished = QtCore.Signal(int)
def __init__(
    self,
    labels_filename: str,
    labels: Labels,
    mode: str = "expert",
    *args,
    **kwargs,
):
    super(InferenceDialog, self).__init__(*args, **kwargs)

    self.labels_filename = labels_filename
    self.labels = labels
    self.mode = mode

    # Populated later via the frame_selection property setter.
    self._frame_selection = None

    # In inference-only mode, restrict job menus to already-trained models.
    self._job_filter = None
    if self.mode == "inference":
        self._job_filter = lambda job: job.is_trained

    print(f"Number of frames to train on: {len(labels.user_labeled_frames)}")

    title = dict(
        learning="Training and Inference",
        inference="Inference",
        expert="Inference Pipeline",
    )

    # Build the settings form from the YAML definition matching this mode.
    self.form_widget = YamlFormWidget.from_name(
        form_name="inference_forms",
        which_form=self.mode,
        title=title[self.mode] + " Settings",
    )
    self.setWindowTitle(title[self.mode])

    # form ui
    # In "learning" mode the confmap menu should list only plain
    # (non-topdown) confidence-map profiles.
    is_confmap_strict = self.mode == "learning"

    # Collect whichever job-selection combo boxes this form defines; not
    # every mode's form has all three.
    job_option_widgets = dict()
    if "_conf_job" in self.form_widget.fields:
        job_option_widgets["confmap"] = self.form_widget.fields["_conf_job"]
    if "_paf_job" in self.form_widget.fields:
        job_option_widgets["paf"] = self.form_widget.fields["_paf_job"]
    if "_centroid_job" in self.form_widget.fields:
        job_option_widgets["centroid"] = self.form_widget.fields["_centroid_job"]

    self.job_menu_manager = JobMenuManager(
        labels_filename,
        job_option_widgets,
        require_trained=(self.mode == "inference"),
        strict_confmap_type=is_confmap_strict,
        menu_selection_callback=self.on_job_menu_selection,
    )
    self.job_menu_manager.rebuild()
    self.job_menu_manager.update_menus(init=True)

    # Run/Cancel buttons plus a status line for validation messages.
    buttons = QtWidgets.QDialogButtonBox()
    self.cancel_button = buttons.addButton(QtWidgets.QDialogButtonBox.Cancel)
    self.run_button = buttons.addButton(
        "Run " + title[self.mode], QtWidgets.QDialogButtonBox.AcceptRole
    )

    self.status_message = QtWidgets.QLabel("hi!")

    buttons_layout = QtWidgets.QHBoxLayout()
    buttons_layout.addWidget(self.status_message)
    buttons_layout.addWidget(buttons, alignment=QtCore.Qt.AlignTop)

    buttons_layout_widget = QtWidgets.QWidget()
    buttons_layout_widget.setLayout(buttons_layout)

    layout = QtWidgets.QVBoxLayout()
    layout.addWidget(self.form_widget)
    layout.addWidget(buttons_layout_widget)
    self.setLayout(layout)

    # connect actions to buttons
    # TODO: fix
    def edit_conf_profile():
        self._view_profile(self.form_widget["_conf_job"], menu_name="confmap")

    def edit_paf_profile():
        self._view_profile(
            self.form_widget["_paf_job"], menu_name="paf",
        )

    def edit_cent_profile():
        self._view_profile(self.form_widget["_centroid_job"], menu_name="centroid")

    if "_view_conf" in self.form_widget.buttons:
        self.form_widget.buttons["_view_conf"].clicked.connect(edit_conf_profile)
    if "_view_paf" in self.form_widget.buttons:
        self.form_widget.buttons["_view_paf"].clicked.connect(edit_paf_profile)
    # NOTE(review): "_view_centoids" looks like a typo for "_view_centroids";
    # presumably it matches the key used in the YAML form — confirm.
    if "_view_centoids" in self.form_widget.buttons:
        self.form_widget.buttons["_view_centoids"].clicked.connect(
            edit_cent_profile
        )
    if "_view_datagen" in self.form_widget.buttons:
        self.form_widget.buttons["_view_datagen"].clicked.connect(self.view_datagen)

    # Re-validate whenever any form value changes.
    self.form_widget.valueChanged.connect(lambda: self.update_gui())

    buttons.accepted.connect(self.run)
    buttons.rejected.connect(self.reject)

    self.update_gui()
@property
def frame_selection(self) -> Dict[str, Dict[Video, List[int]]]:
    """
    Returns dictionary with frames that user has selected for inference.

    Keys are selection types (e.g. "random", "suggestions", "clip",
    "video", "frame"); values map each Video to a list of frame indices.
    """
    return self._frame_selection
@frame_selection.setter
def frame_selection(self, frame_selection: Dict[str, Dict[Video, List[int]]]):
    """Sets options of frames on which to run inference.

    Also rebuilds the "_predict_frames" menu so its entries show the
    frame counts implied by the new selection.
    """
    self._frame_selection = frame_selection

    if "_predict_frames" in self.form_widget.fields.keys():
        prediction_options = []

        def count_total_frames(videos_frames):
            # Sum frame counts across videos. A two-element list whose
            # second value is negative encodes an inclusive range.
            if not videos_frames:
                return 0
            count = 0
            for frame_list in videos_frames.values():
                # Check for [x, Y] range given as X, -Y
                # (we're not using typical [X, Y)-style range here)
                if len(frame_list) == 2 and frame_list[1] < 0:
                    count += -frame_list[1] - frame_list[0]
                elif frame_list != (0, 0):
                    count += len(frame_list)
            return count

        # Determine which options are available given _frame_selection
        total_random = count_total_frames(self._frame_selection["random"])
        total_suggestions = count_total_frames(self._frame_selection["suggestions"])
        clip_length = count_total_frames(self._frame_selection["clip"])
        video_length = count_total_frames(self._frame_selection["video"])

        # Build list of options; the default tracks the most specific
        # non-empty selection (suggestions/clip win over random).
        if self.mode != "inference":
            prediction_options.append("nothing")
        prediction_options.append("current frame")

        option = f"random frames ({total_random} total frames)"
        prediction_options.append(option)
        default_option = option

        if total_suggestions > 0:
            option = f"suggested frames ({total_suggestions} total frames)"
            prediction_options.append(option)
            default_option = option

        if clip_length > 0:
            option = f"selected clip ({clip_length} frames)"
            prediction_options.append(option)
            default_option = option

        prediction_options.append(f"entire video ({video_length} frames)")

        self.form_widget.fields["_predict_frames"].set_options(
            prediction_options, default_option
        )
def show(self):
    """Display the dialog, refreshing the profile menus first.

    The dialog is hidden rather than closed so settings persist between
    showings.
    """
    super().show()
    # TODO: keep selection and any items added from training editor
    mgr = self.job_menu_manager
    mgr.rebuild()
    mgr.update_menus()
def update_gui(self):
    """Updates gui state after user changes to options.

    Enforces dependent settings (centroids => cropping), validates that
    selected trained models are mutually compatible, and enables or
    disables the Run button with an explanatory status message.
    """
    form_data = self.form_widget.get_form_data()

    can_run = True

    use_centroids = form_data.get("_use_centroids", False)
    if "_use_centroids" in self.form_widget.fields:
        if form_data.get("_use_trained_centroids", False):
            # you must use centroids if you are using a centroid model
            use_centroids = True
            self.form_widget.set_form_data(dict(_use_centroids=True))
            self.form_widget.fields["_use_centroids"].setEnabled(False)
        else:
            self.form_widget.fields["_use_centroids"].setEnabled(True)

    if use_centroids:
        # you must crop if you are using centroids
        self.form_widget.set_form_data(dict(instance_crop=True))
        self.form_widget.fields["instance_crop"].setEnabled(False)
    else:
        self.form_widget.fields["instance_crop"].setEnabled(True)

    error_messages = []

    if form_data.get("_use_trained_confmaps", False) and form_data.get(
        "_use_trained_pafs", False
    ):
        # make sure trained models are compatible
        conf_job, _ = self.job_menu_manager.get_current_job("confmap")
        paf_job, _ = self.job_menu_manager.get_current_job("paf")

        # only check compatible if we have both profiles
        if conf_job is not None and paf_job is not None:
            if conf_job.trainer.scale != paf_job.trainer.scale:
                can_run = False
                error_messages.append(
                    f"training image scale for confmaps ({conf_job.trainer.scale}) does not match pafs ({paf_job.trainer.scale})"
                )
            if conf_job.trainer.instance_crop != paf_job.trainer.instance_crop:
                can_run = False
                crop_model_name = (
                    "confmaps" if conf_job.trainer.instance_crop else "pafs"
                )
                error_messages.append(
                    f"exactly one model ({crop_model_name}) was trained on crops"
                )
            if use_centroids and not conf_job.trainer.instance_crop:
                can_run = False
                error_messages.append(
                    f"models used with centroids must be trained on cropped images"
                )

    message = ""
    if not can_run:
        message = (
            "Unable to run with selected models:\n- "
            + ";\n- ".join(error_messages)
            + "."
        )
    self.status_message.setText(message)

    self.run_button.setEnabled(can_run)
def _get_model_types_to_use(self):
    """Return the menu names for the model heads enabled by the form.

    Always includes a confidence-map head ("topdown" or "confmap");
    adds "paf" and "centroid" unless the form settings disable them.
    """
    form_data = self.form_widget.get_form_data()
    grouping = form_data.get("_grouping_method", "")

    # TODO: check _grouping_method for confmaps vs topdown
    # Always include one confidence-map head.
    types_to_use = ["topdown" if "topdown" in grouping else "confmap"]

    # Part affinity fields default on; any of these settings turns them off.
    use_pafs = form_data.get("_use_pafs", True)
    if form_data.get("_dont_use_pafs", False):
        use_pafs = False
    elif form_data.get("_multi_instance_mode", "") == "single":
        use_pafs = False
    elif "topdown" in grouping:
        use_pafs = False
    if use_pafs:
        types_to_use.append("paf")

    # Centroids default on; turned off explicitly or by full-frame proposals.
    use_centroids = form_data.get("_use_centroids", True) and (
        form_data.get("_region_proposal_mode", "") != "full frame"
    )
    if use_centroids:
        types_to_use.append("centroid")

    return types_to_use
def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]:
    """
    Returns all currently selected training jobs.

    Form fields which match job parameters override values in saved jobs.
    """
    form_data = self.form_widget.get_form_data()
    training_jobs = dict()

    # In inference-only mode, assume trained models unless the form says no.
    default_use_trained = self.mode == "inference"

    for menu_name in self._get_model_types_to_use():
        job, _ = self.job_menu_manager.get_current_job(menu_name)
        if job is None:
            continue

        model_type = job.model.output_type

        if model_type != ModelOutputType.CENTROIDS:
            # update training job from params in form
            trainer = job.trainer
            for key, val in form_data.items():
                # check if field name is [var]_[model_type] (eg sigma_confmaps)
                if key.split("_")[-1] == str(model_type):
                    key = "_".join(key.split("_")[:-1])
                # check if form field matches attribute of Trainer object
                if key in dir(trainer):
                    setattr(trainer, key, val)

        # Use already trained model if desired
        if form_data.get(f"_use_trained_{str(model_type)}", default_use_trained):
            job.use_trained_model = True
        elif model_type == ModelOutputType.TOPDOWN_CONFIDENCE_MAP:
            # topdown confmaps share the regular confmaps checkbox
            if form_data.get(f"_use_trained_confmaps", default_use_trained):
                job.use_trained_model = True

        # Clear parameters that shouldn't be copied
        job.val_set_filename = None
        job.test_set_filename = None

        training_jobs[model_type] = job

    return training_jobs
def run(self):
    """Run training (or inference) with current dialog settings."""
    # Collect TrainingJobs and params from form
    form_data = self.form_widget.get_form_data()
    training_jobs = self._get_current_training_jobs()

    # Close the dialog now that we have the data from it
    self.accept()

    # Resolve the user's "predict on..." choice to the matching
    # video -> frame-list dict captured via the frame_selection property.
    frames_to_predict = dict()
    if self._frame_selection is not None:
        predict_frames_choice = form_data.get("_predict_frames", "")
        if predict_frames_choice.startswith("current frame"):
            frames_to_predict = self._frame_selection["frame"]
        elif predict_frames_choice.startswith("random"):
            frames_to_predict = self._frame_selection["random"]
        elif predict_frames_choice.startswith("selected clip"):
            frames_to_predict = self._frame_selection["clip"]
        elif predict_frames_choice.startswith("suggested"):
            frames_to_predict = self._frame_selection["suggestions"]
        elif predict_frames_choice.startswith("entire video"):
            frames_to_predict = self._frame_selection["video"]

    # Convert [X, Y+1) ranges to [X, Y] ranges for inference cli
    for video, frame_list in frames_to_predict.items():
        # Check for [A, -B] list representing [A, B) range
        if len(frame_list) == 2 and frame_list[1] < 0:
            frame_list = (frame_list[0], frame_list[1] + 1)
        frames_to_predict[video] = frame_list

    # Run training/inference pipeline using the TrainingJobs
    new_counts = run_learning_pipeline(
        labels_filename=self.labels_filename,
        labels=self.labels,
        training_jobs=training_jobs,
        inference_params=form_data,
        frames_to_predict=frames_to_predict,
    )

    self.learningFinished.emit(new_counts)

    # Negative count signals failure/abort; skip the success message.
    if new_counts >= 0:
        QtWidgets.QMessageBox(
            text=f"Inference has finished. Instances were predicted on {new_counts} frames."
        ).exec_()
def view_datagen(self):
    """Shows windows with sample visual data that will be used training.

    Builds a small demo dataset with the current form settings (scale,
    sigma, cropping, augmentation) and opens preview windows for
    confidence maps and, if enabled, part affinity fields.
    """
    from sleap.nn import data
    from sleap.io.video import Video
    from sleap.gui.overlays.confmaps import demo_confmaps
    from sleap.gui.overlays.pafs import demo_pafs

    training_data = data.TrainingData.from_labels(self.labels)
    ds = training_data.to_ds()

    conf_job, _ = self.job_menu_manager.get_current_job("confmap")

    # settings for datagen; form values override the profile's defaults
    form_data = self.form_widget.get_form_data()
    scale = form_data.get("scale", conf_job.trainer.scale)
    sigma = form_data.get("sigma", None)
    sigma_confmaps = form_data.get("sigma_confmaps", sigma)
    sigma_pafs = form_data.get("sigma_pafs", sigma)
    instance_crop = form_data.get("instance_crop", conf_job.trainer.instance_crop)
    bounding_box_size = form_data.get(
        "bounding_box_size", conf_job.trainer.bounding_box_size
    )

    # Augment dataset. Rotation/noise/negative-sample params are
    # currently disabled; only scale/contrast/brightness are applied.
    aug_params = dict(
        scale=form_data.get("scale", conf_job.trainer.scale),
        contrast=conf_job.trainer.augment_contrast,
        contrast_min_gamma=conf_job.trainer.augment_contrast_min_gamma,
        contrast_max_gamma=conf_job.trainer.augment_contrast_max_gamma,
        brightness=conf_job.trainer.augment_brightness,
        brightness_val=conf_job.trainer.augment_brightness_val,
    )
    ds = data.augment_dataset(ds, **aug_params)

    # Fall back to an estimated crop size when none is configured.
    if bounding_box_size is None or bounding_box_size <= 0:
        bounding_box_size = data.estimate_instance_crop_size(
            training_data.points,
            min_multiple=conf_job.model.input_min_multiple,
            padding=conf_job.trainer.instance_crop_padding,
        )

    if instance_crop:
        ds = data.instance_crop_dataset(
            ds, box_height=bounding_box_size, box_width=bounding_box_size
        )

    skeleton = self.labels.skeletons[0]

    if conf_job.model.output_type == ModelOutputType.CONFIDENCE_MAP:
        conf_data = data.make_confmap_dataset(
            ds, output_scale=scale, sigma=sigma_confmaps,
        )
    elif conf_job.model.output_type == ModelOutputType.TOPDOWN_CONFIDENCE_MAP:
        conf_data = data.make_instance_confmap_dataset(
            ds, with_ctr_peaks=True, output_scale=scale, sigma=sigma_confmaps,
        )
    # NOTE(review): conf_data is unbound if the selected profile has any
    # other output type — confirm that case can't reach this method.

    imgs = []
    confmaps = []
    for img, confmap in conf_data.take(10):
        # presumably the topdown dataset yields tuples whose first item
        # is the confmap — TODO confirm against make_instance_confmap_dataset
        if type(confmap) == tuple:
            confmap = confmap[0]
        imgs.append(img)
        confmaps.append(confmap)
    imgs = np.stack(imgs)
    confmaps = np.stack(confmaps)

    conf_vid = Video.from_numpy(imgs * 255)
    conf_win = demo_confmaps(confmaps, conf_vid)
    conf_win.activateWindow()
    conf_win.resize(bounding_box_size + 50, bounding_box_size + 50)
    conf_win.move(200, 200)

    if ModelOutputType.PART_AFFINITY_FIELD in self._get_current_training_jobs():
        paf_data = data.make_paf_dataset(
            ds,
            data.SimpleSkeleton.from_skeleton(skeleton).edges,
            output_scale=scale,
            distance_threshold=sigma_pafs,
        )

        imgs = []
        pafs = []
        for img, paf in paf_data.take(10):
            imgs.append(img)
            pafs.append(paf)
        imgs = np.stack(imgs)
        pafs = np.stack(pafs)

        paf_vid = Video.from_numpy(imgs * 255)
        paf_win = demo_pafs(pafs, paf_vid)
        paf_win.activateWindow()
        paf_win.resize(bounding_box_size + 50, bounding_box_size + 50)
        paf_win.move(220 + conf_win.rect().width(), 200)

    # FIXME: hide dialog so user can see other windows
    # can we show these windows without closing dialog?
    self.hide()
def _view_profile(self, filename: str, menu_name: str, windows=None):
    """Open the training profile editor in a modal dialog window.

    Args:
        filename: Path of the training profile to edit.
        menu_name: The job menu ("confmap"/"paf"/"centroid") that should
            receive any profiles saved from the editor.
        windows: Optional list that collects the editor window. The
            original used a mutable default ([]) shared across calls;
            since the editor is modal (exec_) the local reference keeps
            it alive, so a per-call list is safe and avoids that pitfall.
    """
    if windows is None:
        windows = []

    saved_files = []
    win = TrainingEditor(
        filename,
        saved_files=saved_files,
        skeleton=self.labels.skeletons[0],
        parent=self,
    )
    windows.append(win)
    win.exec_()  # modal; blocks until the editor is closed

    # Register any profiles the user saved so they show up in the menu.
    for new_filename in saved_files:
        self.job_menu_manager.add_job_to_list(new_filename, menu_name)
def update_fields_from_job(self, job: TrainingJob):
    """Copy trainer parameters from a selected profile into the form.

    Args:
        job: The TrainingJob whose trainer settings populate the form.
    """
    model_type = job.model.output_type

    training_params = cattr.unstructure(job.trainer)
    # Type-suffixed copies, e.g. "sigma" -> "sigma_confmaps".
    training_params_specific = {
        f"{key}_{str(model_type)}": val for key, val in training_params.items()
    }
    # confmap and paf models should share some params shown in dialog (e.g. scale)
    # but centroids does not, so just set any centroid_foo fields from its profile
    if model_type in [ModelOutputType.CENTROIDS]:
        training_params = training_params_specific
    else:
        training_params = {**training_params, **training_params_specific}
    self.form_widget.set_form_data(training_params)

    # is the model already trained?
    is_trained = job.is_trained
    field_name = f"_use_trained_{str(model_type)}"
    # update "use trained" checkbox if present; disable it when there's
    # no trained model to reuse
    if field_name in self.form_widget.fields:
        self.form_widget.fields[field_name].setEnabled(is_trained)
        self.form_widget[field_name] = is_trained
@attr.s(auto_attribs=True)
class JobMenuManager:
    """Manages the training-profile dropdown menus shown in the dialog,
    including discovery of profile json files on disk."""

    # Path to labels file; a sibling "models" directory is searched for
    # profiles saved by previous runs.
    labels_filename: str
    job_option_widgets: dict  # combobox widgets, keyed by menu name
    job_options_by_menu: dict = attr.ib(factory=dict)  # (path, job) lists, keyed by menu name
    # When True, topdown confmap profiles are kept out of the "confmap" menu.
    strict_confmap_type: bool = False
    # When True, only jobs with an already-trained model are offered.
    require_trained: bool = False
    # Called as callback(menu_name, index, field) when a menu selection changes.
    menu_selection_callback: Optional[Callable] = None
def rebuild(self):
    """
    Rebuilds list of profile options (checking for new profile files).
    """
    # load list of job profiles from directory
    profile_dir = util.get_package_file("sleap/training_profiles")

    self.job_options_by_menu = dict()

    # list any profiles from previous runs (searched recursively under
    # the "models" directory next to the labels file)
    if self.labels_filename:
        models_dir = os.path.join(os.path.dirname(self.labels_filename), "models")
        if os.path.exists(models_dir):
            self.find_saved_jobs(models_dir, self.job_options_by_menu)

    # list default profiles (without searching subdirs)
    self.find_saved_jobs(profile_dir, self.job_options_by_menu, depth=0)

    # Apply any filters: in inference mode, keep only trained jobs.
    if self.require_trained:
        for key, jobs_list in self.job_options_by_menu.items():
            self.job_options_by_menu[key] = [
                (path, job) for (path, job) in jobs_list if job.is_trained
            ]
def get_menu_options(self, menu_name: str) -> list:
    """Return the list of (path, TrainingJob) tuples for a menu.

    Args:
        menu_name: Menu key, e.g. "confmap", "paf", or "centroid".

    Returns:
        The menu's (profile path, TrainingJob) tuples; an empty list
        when the menu has no discovered profiles.
    """
    # Single dict lookup replaces the original membership-test-then-index
    # pair; dead commented-out code removed.
    return self.job_options_by_menu.get(menu_name, [])
def option_list_from_jobs_list(self, jobs):
    """Build the combobox entries for a menu: one entry per profile
    name, then a blank row, a separator, and the browse-for-file item."""
    names = [name for (name, _job) in jobs]
    return names + ["", "---", SELECT_FILE_OPTION]
def update_menus(self, init: bool = False):
    """Updates every job menu with current training profile options.

    Args:
        init: Whether this is first time calling (so we should connect
            signals), or we're just updating menus.

    Returns:
        None.
    """
    for name in self.job_option_widgets:
        self.update_menu(name, init=init)
def update_menu(
    self,
    menu_name,
    select_item: Optional[str] = None,
    init: bool = False,
    signal: bool = False,
):
    """Refresh one job menu's options.

    Args:
        menu_name: Which menu to refresh.
        select_item: Option to select after rebuilding, or None.
        init: True on first call; connects the selection-changed signal
            instead of blocking it.
        signal: When True, let the selection-changed signal fire while
            options are replaced (used when programmatically selecting).
    """
    menu_options = self.get_menu_options(menu_name)
    field = self.job_option_widgets[menu_name]

    if init:
        # Bind menu/field as default args so each closure captures its
        # own values rather than the loop's last iteration.
        def menu_action(idx, menu=menu_name, field=field):
            self.menu_selection_callback(menu, idx, field)

        field.currentIndexChanged.connect(menu_action)
    elif not signal:
        # block signals so we can update combobox without overwriting
        # any user data with the defaults from the profile
        field.blockSignals(True)
    field.set_options(self.option_list_from_jobs_list(menu_options), select_item)
    # enable signals again so that choice of profile will update params
    # (unconditional: harmless when signals were never blocked)
    field.blockSignals(False)
def get_current_job(
    self, menu_name: str
) -> Tuple[Optional[TrainingJob], Optional[str]]:
    """Returns training job currently selected for given menu.

    Args:
        menu_name: The menu ("confmap"/"paf"/"centroid") to query.

    Returns: Tuple of (TrainingJob, path to job profile); both None when
        no valid item is selected.
    """
    # by default use the first model for a given type
    idx = 0
    # If there's a menu, then use the selected item
    if menu_name in self.job_option_widgets:
        field = self.job_option_widgets[menu_name]
        idx = field.currentIndex()  # may be -1 when nothing is selected
    job_filename, job = self.get_menu_item(menu_name, idx)
    return job, job_filename
def get_menu_item(
    self, menu_name: str, item_idx: int
) -> Tuple[Optional[str], Optional[TrainingJob]]:
    """Return the (path, TrainingJob) entry at a position in a menu.

    Args:
        menu_name: Menu key, e.g. "confmap".
        item_idx: Index into the menu's option list. A Qt combobox with
            no selection reports -1; the original only checked the upper
            bound, so -1 silently wrapped around to the *last* option.

    Returns:
        The (path, TrainingJob) tuple, or (None, None) when the index is
        out of range (including negative indices).
    """
    menu_options = self.get_menu_options(menu_name)
    if not (0 <= item_idx < len(menu_options)):
        return None, None
    return menu_options[item_idx]
def insert_menu_item(self, menu_name: str, job_path, job):
    """Prepend a (path, job) entry to a menu's options and select it."""
    options = self.job_options_by_menu[menu_name]
    options.insert(0, (job_path, job))
    self.update_menu(menu_name, select_item=job_path, signal=True)
def add_job_gui(self, menu_name: str):
    """Allow user to add training profile for given model type.

    Opens a file dialog; if a valid profile is chosen it is added to the
    menu, otherwise the menu selection is cleared.
    """
    filename, _ = FileDialog.open(
        None,
        dir=None,
        caption="Select training profile...",
        filter="TrainingJob JSON (*.json)",
    )

    self.add_job_to_list(filename, menu_name)

    # If we didn't successfully select a new file, then reset menu selection
    # (the last combobox row is the "Select file..." sentinel).
    field = self.job_option_widgets[menu_name]
    if field.currentIndex() == field.count() - 1:  # subtract 1 for separator
        field.setCurrentIndex(-1)
def add_job_to_list(self, filename: str, menu_name: str):
    """Adds selected training profile for given model type.

    Args:
        filename: Path to the profile json; ignored when empty (e.g.
            when the user cancelled the file dialog).
        menu_name: The menu the user was adding a profile to.
    """
    if len(filename):
        try:
            # try to load json as TrainingJob
            job = TrainingJob.load_json(filename)
        except Exception:
            # Narrowed from bare `except:` so Ctrl-C isn't swallowed.
            # Tell the user which file failed (the original message had
            # lost its filename placeholder), then re-raise.
            QtWidgets.QMessageBox(
                text=f"Unable to load a training profile from {filename}."
            ).exec_()
            raise
        else:
            # Get the model type for the model/profile selected by user
            file_model_type = job.model.output_type
            # Make sure this is the right type for this menu
            if file_model_type in MENU_NAME_TYPE_MAP[menu_name]:
                self.insert_menu_item(menu_name, filename, job)
            else:
                QtWidgets.QMessageBox(
                    text=f"Profile selected is for training {str(file_model_type)} instead of {menu_name}."
                ).exec_()
def find_saved_jobs(
    self, job_dir: str, jobs=None, depth: int = 1
) -> Dict[ModelOutputType, List[Tuple[str, TrainingJob]]]:
    """Find all the TrainingJob json files in a given directory.

    Args:
        job_dir: the directory in which to look for json files
        jobs: If given, then the found jobs will be added to this object,
            rather than creating new dict.
        depth: How many levels of subdirectories to search (0 = only
            job_dir itself).

    Returns:
        dict of {menu name: list of (filename, TrainingJob) tuples}
    """
    json_files = util.find_files_by_suffix(job_dir, ".json", depth=depth)
    # Sort files, starting with most recently modified
    json_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
    json_paths = [file.path for file in json_files]

    jobs = dict() if jobs is None else jobs
    for full_filename in json_paths:
        try:
            # try to load json as TrainingJob
            job = TrainingJob.load_json(full_filename)
        except Exception as e:
            # Couldn't load as TrainingJob so just ignore this json file
            # probably it's a json file for something else (or an json for a
            # older version of the object with different class attributes).
            print(e)
            pass
        else:
            # we loaded the json as a TrainingJob, so see what type of model it's for
            key = self.menu_name_from_model_type(job.model.output_type)
            if key not in jobs:
                jobs[key] = []

            # See if this job file is already in list, and if so, update job
            # (recently-modified files were processed first, so duplicates
            # found later are older entries).
            existing_job_filenames = [filename for filename, job in jobs[key]]
            if full_filename in existing_job_filenames:
                existing_idx = existing_job_filenames.index(full_filename)
                jobs[key][existing_idx] = (full_filename, job)
            else:
                # It's not already in list, so add it
                jobs[key].append((full_filename, job))

    return jobs
def menu_name_from_model_type(self, model_type):
    """Map a ModelOutputType to the name of the menu it belongs in.

    Topdown confidence maps normally share the "confmap" menu; when
    strict_confmap_type is set they get their own "topdown" name.
    Unknown types map to "".
    """
    if model_type == ModelOutputType.CONFIDENCE_MAP:
        return "confmap"
    if model_type == ModelOutputType.TOPDOWN_CONFIDENCE_MAP:
        return "topdown" if self.strict_confmap_type else "confmap"
    if model_type == ModelOutputType.CENTROIDS:
        return "centroid"
    if model_type == ModelOutputType.PART_AFFINITY_FIELD:
        return "paf"
    return ""
def run_learning_pipeline(
    labels_filename: str,
    labels: Labels,
    training_jobs: Dict["ModelOutputType", "TrainingJob"],
    inference_params: Dict[str, str],
    frames_to_predict: Dict[Video, List[int]] = None,
) -> int:
    """Run training (as needed) and inference.

    Args:
        labels_filename: Path to already saved current labels object.
        labels: The current labels object; results will be added to this.
        training_jobs: The TrainingJobs with params/hyperparams for training.
        inference_params: Parameters to pass to inference.
        frames_to_predict: Dict that gives list of frame indices for each video.

    Returns:
        Number of new frames added to labels, or -1 if any model failed
        to train.
    """
    # Set the parameters specific to this run
    for job in training_jobs.values():
        job.labels_filename = labels_filename
    # TODO: only require labels_filename if we're training?

    save_viz = inference_params.get("_save_viz", False)

    # Train the TrainingJobs
    trained_jobs = run_gui_training(
        labels_filename, training_jobs, gui=True, save_viz=save_viz
    )

    # Check that all the models were trained
    if None in trained_jobs.values():
        return -1
    trained_job_paths = list(trained_jobs.values())

    # Run the Predictor for suggested frames
    new_labeled_frame_count = run_gui_inference(
        labels=labels,
        trained_job_paths=trained_job_paths,
        inference_params=inference_params,
        frames_to_predict=frames_to_predict,
        labels_filename=labels_filename,
    )

    return new_labeled_frame_count
def has_jobs_to_train(training_jobs: Dict["ModelOutputType", "TrainingJob"]):
    """Returns whether any of the jobs need to be trained.

    A job needs training unless it has use_trained_model set truthy.

    Fix: the original iterated the dict directly, which yields the *keys*
    (ModelOutputType values) rather than the TrainingJob values, so
    `getattr(key, "use_trained_model", False)` was always False and the
    function returned True for every non-empty dict. Iterate the values.
    """
    return any(
        not getattr(job, "use_trained_model", False)
        for job in training_jobs.values()
    )
def run_gui_training(
    labels_filename: str,
    training_jobs: Dict["ModelOutputType", "TrainingJob"],
    gui: bool = True,
    save_viz: bool = False,
) -> Dict["ModelOutputType", str]:
    """
    Run training for each training job.

    Args:
        labels_filename: Path to labels file with the training data.
        training_jobs: Dict of the jobs to train.
        gui: Whether to show gui windows and process gui events.
        save_viz: Whether to save visualizations from training.

    Returns:
        Dict of paths to trained jobs corresponding with input training
        jobs; a value is None when that job's training failed.
    """
    from sleap.nn import training

    trained_jobs = dict()

    if gui:
        from sleap.nn.monitor import LossViewer
        from sleap.gui.imagedir import QtImageDirectoryWidget

        # open training monitor window
        win = LossViewer()
        win.resize(600, 400)
        win.show()

    for model_type, job in training_jobs.items():
        if getattr(job, "use_trained_model", False):
            # set path to TrainingJob already trained from previous run
            trained_jobs[model_type] = job.run_path
            print(f"Using already trained model: {trained_jobs[model_type]}")
        else:
            # Update save dir and run name for job we're about to train
            # so we have access to them here (rather than letting
            # train_subprocess update them).
            training.Trainer.set_run_name(job, labels_filename)

            if gui:
                print("Resetting monitor window.")
                win.reset(what=str(model_type))
                win.setWindowTitle(f"Training Model - {str(model_type)}")
                if save_viz:
                    # live preview of training visualizations next to monitor
                    viz_window = QtImageDirectoryWidget.make_training_vizualizer(
                        job.run_path
                    )
                    viz_window.move(win.x() + win.width() + 20, win.y())
                    win.on_epoch.connect(viz_window.poll)

            print(f"Start training {str(model_type)}...")

            def waiting():
                # Keep the GUI responsive while the subprocess runs.
                if gui:
                    QtWidgets.QApplication.instance().processEvents()

            # Run training
            trained_job_path, success = training.Trainer.train_subprocess(
                job,
                labels_filename,
                waiting_callback=waiting,
                update_run_name=False,
                save_viz=save_viz,
            )

            if success:
                # get the path to the resulting TrainingJob file
                trained_jobs[model_type] = trained_job_path
                print(f"Finished training {str(model_type)}.")
            else:
                if gui:
                    win.close()
                    QtWidgets.QMessageBox(
                        text=f"An error occurred while training {str(model_type)}. Your command line terminal may have more information about the error."
                    ).exec_()
                trained_jobs[model_type] = None

    if gui:
        # close training monitor window
        win.close()

    return trained_jobs
def run_gui_inference(
    labels: Labels,
    trained_job_paths: List[str],
    frames_to_predict: Dict[Video, List[int]],
    inference_params: Dict[str, str],
    labels_filename: str,
    gui: bool = True,
) -> int:
    """Run inference on specified frames using models from training_jobs.

    Args:
        labels: The current labels object; results will be added to this.
        trained_job_paths: List of paths to TrainingJobs with trained models.
        frames_to_predict: Dict that gives list of frame indices for each video.
        inference_params: Parameters to pass to inference.
        labels_filename: Path to labels file, passed to the inference
            subprocess.
        gui: Whether to show gui windows and process gui events.

    Returns:
        Number of new frames added to labels, or -1 on error.
    """
    from sleap.nn import inference

    if gui:
        # show message while running inference
        progress = QtWidgets.QProgressDialog(
            f"Running inference on {len(frames_to_predict)} videos...",
            "Cancel",
            0,
            len(frames_to_predict),
        )
        progress.show()
        QtWidgets.QApplication.instance().processEvents()

    new_lfs = []
    for i, (video, frames) in enumerate(frames_to_predict.items()):
        if len(frames):

            def waiting():
                # Keep the GUI responsive while the subprocess runs.
                if gui:
                    QtWidgets.QApplication.instance().processEvents()
                    progress.setValue(i)
                    if progress.wasCanceled():
                        # NOTE(review): this return value is discarded by
                        # whoever invokes the callback, so cancelling here
                        # doesn't actually abort inference — confirm.
                        return -1

            # Run inference for desired frames in this video
            predictions_path, success = inference.Predictor.predict_subprocess(
                video=video,
                frames=frames,
                trained_job_paths=trained_job_paths,
                kwargs=inference_params,
                waiting_callback=waiting,
                labels_filename=labels_filename,
            )

            if success:
                predictions_labels = Labels.load_file(predictions_path, match_to=labels)
                new_lfs.extend(predictions_labels.labeled_frames)
            else:
                if gui:
                    progress.close()
                    # Fix: message typo "occcured" -> "occurred".
                    QtWidgets.QMessageBox(
                        text="An error occurred during inference. Your command line terminal may have more information about the error."
                    ).exec_()
                return -1

    # Remove any frames without instances
    new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs))

    # Merge predictions into current labels dataset
    _, _, new_conflicts = Labels.complex_merge_between(
        labels,
        new_labels=Labels(new_lfs),
        unify=False,  # since we used match_to when loading predictions file
    )

    # new predictions should replace old ones
    Labels.finish_complex_merge(labels, new_conflicts)

    # close message window
    if gui:
        progress.close()

    return len(new_lfs)
if __name__ == "__main__":
    import sys

    # Launch the dialog standalone in inference mode against the labels
    # file given as the first command-line argument.
    labels_filename = sys.argv[1]
    labels = Labels.load_file(labels_filename)

    app = QtWidgets.QApplication()
    win = InferenceDialog(
        labels=labels, labels_filename=labels_filename, mode="inference"
    )
    win.show()
    app.exec_()