# Source code for sleap.gui.commands

"""
Module for gui command context and commands objects.

Each open project (i.e., `MainWindow`) will have its own `CommandContext`.
The context enables commands to access and modify the `GuiState` and `Labels`,
as well as potentially maintaining a command history (so we can add support for
undo!). See `sleap.gui.app` for how the context is created and used.

Every command will have both a method in `CommandContext` (this is what should
be used to trigger the command, e.g., connected to the menu action) and a
class which inherits from `AppCommand` (or a more specialized class such as
`NavCommand`, `GoIteratorCommand`, or `EditCommand`). Note that this code relies
on inheritance, so some care and attention is required.

A typical command will override the `ask` and `do_action` methods. If the
command updates something which affects the GUI, it should override the `topics`
attribute (this then gets passed back to the `update_callback` from the context).
If a command doesn't require any input from the user, then it doesn't need to
override the `ask` method.

If it's not possible to separate the GUI "ask" and the non-GUI "do" code, then
instead of `ask` and `do_action` you should add an `ask_and_do` method
(for instance, `DeleteDialogCommand` and `MergeProject` show dialogues which
handle both the GUI and the action). Ideally we'd endorse separation of "ask"
and "do" for all commands (this is important if we're going to implement undo)--
for now it's at least easy to see where this separation is violated.
"""

import attr
import operator
import os

from abc import ABC
from enum import Enum
from pathlib import PurePath
from typing import Callable, Dict, Iterator, List, Optional, Type, Tuple

import numpy as np

from PySide2 import QtCore, QtWidgets, QtGui

from PySide2.QtWidgets import QMessageBox

from sleap.gui.dialogs.delete import DeleteDialog
from sleap.skeleton import Skeleton
from sleap.instance import Instance, PredictedInstance, Point, Track, LabeledFrame
from sleap.io.video import Video
from sleap.io.dataset import Labels
from sleap.gui.dialogs.importvideos import ImportVideos
from sleap.gui.dialogs.filedialog import FileDialog
from sleap.gui.dialogs.missingfiles import MissingFilesDialog
from sleap.gui.dialogs.merge import MergeDialog
from sleap.gui.dialogs.message import MessageDialog
from sleap.gui.suggestions import VideoFrameSuggestions
from sleap.gui.state import GuiState


# Whether we support multiple project windows: when True, "open" loads the
# project into a new window instead of the current one (see `OpenProject`).
OPEN_IN_NEW: bool = True


class UpdateTopic(Enum):
    """What a command changed.

    A list of these is handed to the context's update callback so it knows
    which parts of the GUI to refresh.
    """

    all = 1
    video = 2
    skeleton = 3
    labels = 4
    on_frame = 5
    suggestions = 6
    tracks = 7
    frame = 8
    project = 9
    project_instances = 10
class AppCommand:
    """Base class for specific commands.

    This is deliberately not an abstract base class. Subclasses override
    `ask` and/or `do_action`, or provide an `ask_and_do` method when the
    "ask" and "do" steps can't be separated. Most subclasses also override
    `topics` and `does_edits`. The defaults defined here do nothing, so these
    are plain methods/attributes rather than virtual ones and are implemented
    in the base class.

    Do not override `execute` or `do_with_signal`.

    Attributes:
        topics: `UpdateTopic` items to signal once the command has run.
        does_edits: Whether the command modifies data that could be saved.
    """

    topics: List[UpdateTopic] = []
    does_edits: bool = False

    def execute(self, context: "CommandContext", params: dict = None):
        """Runs the command: gather input, act, then notify about changes.

        Information gathering belongs in `ask`. It stores what it collects in
        `params`, which is then handed to `do_action`. `ask` should not modify
        state; keeping it pure will make an eventual `undo_action` (taking the
        same `params`) straightforward.

        A command that can't split gathering from acting implements
        `ask_and_do` instead. That method must call `do_with_signal` itself
        to emit notifications.

        Args:
            context: The `CommandContext` the command runs in. It gives access
                to `MainWindow`, `GuiState`, and `Labels`.
            params: Optional dictionary of parameters for the command.
        """
        params = params or dict()

        combined_step = getattr(self, "ask_and_do", None)
        if callable(combined_step):
            combined_step(context, params)
            return

        if self.ask(context, params):
            self.do_with_signal(context, params)

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        """Gathers information needed by `do_action`.

        Returns:
            Whether to go ahead with the action. The default always proceeds.
            Subclasses return False when, e.g., the user cancels a prompt.
        """
        return True

    @staticmethod
    def do_action(context: "CommandContext", params: dict):
        """Performs the action (no-op by default)."""
        pass

    @classmethod
    def do_with_signal(cls, context: "CommandContext", params: dict):
        """Runs `do_action`, then signals topics and records the edit.

        Don't override this method!
        """
        cls.do_action(context, params)

        changed_topics = cls.topics
        if changed_topics:
            context.signal_update(changed_topics)

        # Only data-modifying commands mark the project as having changes.
        if cls.does_edits:
            context.changestack_push(cls.__name__)
@attr.s(auto_attribs=True)
class FakeApp:
    """Minimal stand-in for `MainWindow` so commands can run without the GUI.

    It only carries `labels`, which is what `CommandContext.labels` reads.
    """

    labels: Labels
@attr.s(auto_attribs=True, eq=False)
class CommandContext:
    """
    Context within which commands are executed.

    When you create a new command, create a class for it (inheriting from
    `AppCommand`) and add a matching method to `CommandContext`. Other code
    should connect to or call that method to invoke the command.

    Attributes:
        state: The `GuiState` object used to store state and pass messages.
        app: The `MainWindow`, available for commands that modify the app.
        update_callback: A callback to receive update notifications.
            This function should accept a list of `UpdateTopic` items.
    """

    state: GuiState
    app: "MainWindow"

    update_callback: Optional[Callable] = None
    # Names of changes pushed so far. Only its length is used at present
    # (to detect unsaved changes); it is a placeholder for future undo/redo.
    _change_stack: List = attr.ib(default=attr.Factory(list))

    @classmethod
    def from_labels(cls, labels: Labels) -> "CommandContext":
        """Creates a command context for use independently of the GUI app."""
        state = GuiState()
        app = FakeApp(labels)
        return cls(state=state, app=app)

    @property
    def labels(self) -> Labels:
        """Alias to `app.labels`."""
        return self.app.labels

    def signal_update(self, what: List[UpdateTopic]):
        """Calls the update callback (if any) after data has been changed."""
        if callable(self.update_callback):
            self.update_callback(what)

    def changestack_push(self, change: str = ""):
        """Records a user change and marks the project as having changes."""
        # The change name isn't used yet; we only use the stack to tell
        # whether there are unsaved changes. Eventually this could support
        # undo/redo.
        self._change_stack.append(change)
        self.state["has_changes"] = True

    def changestack_savepoint(self):
        """Marks that the project was just saved."""
        self.changestack_push("SAVE")
        # Pushing sets has_changes to True, so reset it after the push.
        self.state["has_changes"] = False

    def changestack_clear(self):
        """Clears the stack of changes."""
        self._change_stack = list()
        self.state["has_changes"] = False

    @property
    def has_any_changes(self):
        """Whether anything has been pushed since the stack was last cleared.

        Save markers count too, so this stays True after a save.
        """
        return len(self._change_stack) > 0

    def execute(self, command: Type[AppCommand], **kwargs):
        """Executes `command` in this context, passing named arguments as params."""
        command().execute(context=self, params=kwargs)

    # File commands

    def newProject(self):
        """Creates a new project in a new window."""
        self.execute(NewProject)

    def openProject(self, first_open: bool = False):
        """
        Allows the user to select and then open a saved project.

        Args:
            first_open: Whether this is the first window opened. If True,
                then the new project is loaded into the current window
                rather than a new application window.

        Returns:
            None.
        """
        self.execute(OpenProject, first_open=first_open)

    def importDPK(self):
        """Imports DeepPoseKit datasets."""
        self.execute(ImportDeepPoseKit)

    def importCoco(self):
        """Imports COCO datasets."""
        self.execute(ImportCoco)

    def importDLC(self):
        """Imports DeepLabCut datasets."""
        self.execute(ImportDeepLabCut)

    def importLEAP(self):
        """Imports LEAP Matlab datasets."""
        self.execute(ImportLEAP)

    def importAnalysisFile(self):
        """Imports SLEAP analysis HDF5 files."""
        self.execute(ImportAnalysisFile)

    def saveProject(self):
        """Saves the project (shows "Save As" if it hasn't been saved yet)."""
        self.execute(SaveProject)

    def saveProjectAs(self):
        """Shows gui to save the project as a new file."""
        self.execute(SaveProjectAs)

    def exportAnalysisFile(self):
        """Shows gui for exporting an analysis HDF5 file."""
        self.execute(ExportAnalysisFile)

    def exportLabeledClip(self):
        """Shows gui for exporting a clip with visual annotations."""
        self.execute(ExportLabeledClip)

    def exportUserLabelsPackage(self):
        """Gui for exporting the dataset with user-labeled images."""
        self.execute(ExportUserLabelsPackage)

    def exportTrainingPackage(self):
        """Gui for exporting the dataset with user-labeled images and suggestions."""
        self.execute(ExportTrainingPackage)

    def exportFullPackage(self):
        """Gui for exporting the dataset with any labeled frames and suggestions."""
        self.execute(ExportFullPackage)

    # Navigation Commands

    def previousLabeledFrame(self):
        """Goes to the labeled frame prior to the current frame."""
        self.execute(GoPreviousLabeledFrame)

    def nextLabeledFrame(self):
        """Goes to the labeled frame after the current frame."""
        self.execute(GoNextLabeledFrame)

    def nextUserLabeledFrame(self):
        """Goes to the next labeled frame with user instances."""
        self.execute(GoNextUserLabeledFrame)

    def nextSuggestedFrame(self):
        """Goes to the next suggested frame."""
        self.execute(GoNextSuggestedFrame)

    def prevSuggestedFrame(self):
        """Goes to the previous suggested frame."""
        self.execute(GoPrevSuggestedFrame)

    def nextTrackFrame(self):
        """Goes to the next frame on which a track starts."""
        self.execute(GoNextTrackFrame)

    def gotoFrame(self):
        """Shows gui to go to a frame by number."""
        self.execute(GoFrameGui)

    def selectToFrame(self):
        """Shows gui to select the range from the current frame to a frame number."""
        self.execute(SelectToFrameGui)

    def gotoVideoAndFrame(self, video: Video, frame_idx: int):
        """Activates `video` and goes to `frame_idx`."""
        NavCommand.go_to(self, frame_idx, video)

    # Editing Commands

    def addVideo(self):
        """Shows gui for adding videos to the project."""
        self.execute(AddVideo)

    def replaceVideo(self):
        """Shows gui for replacing videos in the project."""
        self.execute(ReplaceVideo)

    def removeVideo(self):
        """Removes the selected video from the project."""
        self.execute(RemoveVideo)

    def openSkeleton(self):
        """Shows gui for loading a saved skeleton into the project."""
        self.execute(OpenSkeleton)

    def saveSkeleton(self):
        """Shows gui for saving the skeleton from the project."""
        self.execute(SaveSkeleton)

    def newNode(self):
        """Adds a new node to the skeleton."""
        self.execute(NewNode)

    def deleteNode(self):
        """Removes the (currently selected) node from the skeleton."""
        self.execute(DeleteNode)

    def setNodeName(self, skeleton, node, name):
        """Changes the name of `node` in `skeleton`."""
        self.execute(SetNodeName, skeleton=skeleton, node=node, name=name)

    def setNodeSymmetry(self, skeleton, node, symmetry: str):
        """Sets the symmetry of `node` in `skeleton`."""
        self.execute(SetNodeSymmetry, skeleton=skeleton, node=node, symmetry=symmetry)

    def updateEdges(self):
        """Called when edges in the skeleton have been changed."""
        self.signal_update([UpdateTopic.skeleton])

    def newEdge(self, src_node, dst_node):
        """Adds a new edge to the skeleton."""
        self.execute(NewEdge, src_node=src_node, dst_node=dst_node)

    def deleteEdge(self):
        """Removes the (currently selected) edge from the skeleton."""
        self.execute(DeleteEdge)

    def deletePredictions(self):
        """Deletes all predicted instances in the project."""
        self.execute(DeleteAllPredictions)

    def deleteFramePredictions(self):
        """Deletes all predictions on the current frame."""
        self.execute(DeleteFramePredictions)

    def deleteClipPredictions(self):
        """Deletes all predictions within the selected range of video frames."""
        self.execute(DeleteClipPredictions)

    def deleteAreaPredictions(self):
        """Gui for deleting instances within some rect on frame images."""
        self.execute(DeleteAreaPredictions)

    def deleteLowScorePredictions(self):
        """Gui for deleting instances below some score threshold."""
        self.execute(DeleteLowScorePredictions)

    def deleteFrameLimitPredictions(self):
        """Gui for deleting instances beyond some number in each frame."""
        self.execute(DeleteFrameLimitPredictions)

    def completeInstanceNodes(self, instance: Instance):
        """Adds missing nodes to the given instance."""
        self.execute(AddMissingInstanceNodes, instance=instance)

    def newInstance(
        self,
        copy_instance: Optional[Instance] = None,
        init_method: str = "best",
        location: Optional[QtCore.QPoint] = None,
        mark_complete: bool = False,
    ):
        """
        Creates a new instance, copying node coordinates as appropriate.

        Args:
            copy_instance: The :class:`Instance` (or
                :class:`PredictedInstance`) which we want to copy.
            init_method: Method to use for positioning nodes.
            location: The location where the instance should be added (if the
                node init method supports a custom location).
            mark_complete: Passed through to `AddInstance`. Presumably it marks
                the new instance's points as complete; TODO confirm.
        """
        self.execute(
            AddInstance,
            copy_instance=copy_instance,
            init_method=init_method,
            location=location,
            mark_complete=mark_complete,
        )

    def setPointLocations(
        self, instance: Instance, nodes_locations: Dict["Node", Tuple[int, int]]
    ):
        """Sets the locations of one or more nodes for an instance.

        `nodes_locations` maps each node to its new location. Presumably the
        location is an (x, y) pair; verify against `SetInstancePointLocations`.
        """
        self.execute(
            SetInstancePointLocations,
            instance=instance,
            nodes_locations=nodes_locations,
        )

    def setInstancePointVisibility(
        self, instance: Instance, node: "Node", visible: bool
    ):
        """Sets the visibility of `node` for `instance`."""
        self.execute(
            SetInstancePointVisibility, instance=instance, node=node, visible=visible
        )

    def addUserInstancesFromPredictions(self):
        """Runs the `AddUserInstancesFromPredictions` command."""
        self.execute(AddUserInstancesFromPredictions)

    def deleteSelectedInstance(self):
        """Deletes the currently selected instance."""
        self.execute(DeleteSelectedInstance)

    def deleteSelectedInstanceTrack(self):
        """Deletes all instances from the track of the selected instance."""
        self.execute(DeleteSelectedInstanceTrack)

    def deleteDialog(self):
        """Deletes using options selected in a dialog."""
        self.execute(DeleteDialogCommand)

    def addTrack(self):
        """Creates a new track and moves the selected instance into it."""
        self.execute(AddTrack)

    def setInstanceTrack(self, new_track: "Track"):
        """Sets the track for the selected instance."""
        self.execute(SetSelectedInstanceTrack, new_track=new_track)

    def setTrackName(self, track: "Track", name: str):
        """Sets the name of `track`."""
        self.execute(SetTrackName, track=track, name=name)

    def transposeInstance(self):
        """Transposes tracks for two instances.

        If there are only two instances, then this swaps their tracks.
        Otherwise, it lets the user select the instances whose tracks should
        be swapped.
        """
        self.execute(TransposeInstances)

    def importPredictions(self):
        """Starts gui for importing another dataset into the current one."""
        self.execute(MergeProject)

    def generateSuggestions(self, params: Dict):
        """Generates suggestions using the given params dictionary."""
        self.execute(GenerateSuggestions, **params)

    def openWebsite(self, url):
        """Opens a website from `url` using the native system browser."""
        self.execute(OpenWebsite, url=url)

    def checkForUpdates(self):
        """Checks for updates online."""
        self.execute(CheckForUpdates)

    def openStableVersion(self):
        """Opens the current stable version."""
        self.execute(OpenStableVersion)

    def openPrereleaseVersion(self):
        """Opens the current prerelease version."""
        self.execute(OpenPrereleaseVersion)
# File Commands
class NewProject(AppCommand):
    """Opens an empty project in a fresh, maximized application window."""

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        # Build another window of whatever class the current app is.
        window_cls = type(context.app)
        window_cls().showMaximized()
class OpenProject(AppCommand):
    """Prompts for a saved project and loads it, possibly in a new window."""

    @staticmethod
    def do_action(context: "CommandContext", params: dict):
        filename = params["filename"]

        wants_new_window = OPEN_IN_NEW and not params.get("first_open", False)

        # A window with nothing loaded and no user edits is effectively empty,
        # so reuse it rather than spawning another one.
        window_is_blank = (
            not context.state["project_loaded"] and not context.has_any_changes
        )

        target = context.app
        if wants_new_window and not window_is_blank:
            target = type(context.app)()
            target.showMaximized()
        target.loadProjectFile(filename)

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        file_types = ";;".join(
            (
                "SLEAP HDF5 dataset (*.slp *.h5 *.hdf5)",
                "JSON labels (*.json *.json.zip)",
            )
        )
        filename, _ = FileDialog.open(
            context.app,
            dir=None,
            caption="Import labeled data...",
            filter=file_types,
        )
        if len(filename) == 0:
            return False
        params["filename"] = filename
        return True
class ImportDeepPoseKit(AppCommand):
    """Imports a DeepPoseKit dataset into a new window."""

    @staticmethod
    def do_action(context: "CommandContext", params: dict):
        labels = Labels.from_deepposekit(
            filename=params["filename"],
            video_path=params["video_path"],
            skeleton_path=params["skeleton_path"],
        )
        window = type(context.app)()
        window.showMaximized()
        window.loadLabelsObject(labels=labels)

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        filename, _ = FileDialog.open(
            context.app,
            dir=None,
            caption="Import DeepPoseKit dataset...",
            filter="HDF5 (*.h5 *.hdf5)",
        )
        if len(filename) == 0:
            return False

        # The video and skeleton are expected next to the dataset file.
        dataset_dir = os.path.dirname(filename)
        paths = [
            os.path.join(dataset_dir, name) for name in ("video.mp4", "skeleton.csv")
        ]
        missing = [not os.path.exists(path) for path in paths]

        if sum(missing):
            # Lets the user locate missing files. It appears to update `paths`
            # and `missing` in place (NOTE(review): confirm), so both are
            # re-read below.
            accepted = MissingFilesDialog(filenames=paths, missing=missing).exec_()
            if not accepted or sum(missing):
                return False

        params["filename"] = filename
        params["video_path"], params["skeleton_path"] = paths[0], paths[1]
        return True
class ImportLEAP(AppCommand):
    """Imports a LEAP Matlab dataset into a new window."""

    @staticmethod
    def do_action(context: "CommandContext", params: dict):
        labels = Labels.load_leap_matlab(filename=params["filename"])
        window = type(context.app)()
        window.showMaximized()
        window.loadLabelsObject(labels=labels)

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        filename, _ = FileDialog.open(
            context.app,
            dir=None,
            caption="Import LEAP Matlab dataset...",
            filter="Matlab (*.mat)",
        )
        if len(filename) == 0:
            return False
        params["filename"] = filename
        return True
class ImportCoco(AppCommand):
    """Imports a COCO dataset into a new window."""

    @staticmethod
    def do_action(context: "CommandContext", params: dict):
        labels = Labels.load_coco(
            filename=params["filename"],
            img_dir=params["img_dir"],
            use_missing_gui=True,
        )
        window = type(context.app)()
        window.showMaximized()
        window.loadLabelsObject(labels=labels)

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        filename, _ = FileDialog.open(
            context.app,
            dir=None,
            caption="Import COCO dataset...",
            filter="JSON (*.json)",
        )
        if len(filename) == 0:
            return False

        # Images are looked up in the JSON file's own directory; the user is
        # not asked for a separate image directory.
        params["filename"] = filename
        params["img_dir"] = os.path.dirname(filename)
        return True
class ImportDeepLabCut(AppCommand):
    """Imports a DeepLabCut dataset into a new window."""

    @staticmethod
    def do_action(context: "CommandContext", params: dict):
        labels = Labels.load_deeplabcut(filename=params["filename"])
        window = type(context.app)()
        window.showMaximized()
        window.loadLabelsObject(labels=labels)

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        filename, _ = FileDialog.open(
            context.app,
            dir=None,
            caption="Import DeepLabCut dataset...",
            filter="DeepLabCut dataset (*.yaml *.csv)",
        )
        if len(filename) == 0:
            return False
        params["filename"] = filename
        return True
class ImportAnalysisFile(AppCommand):
    """Imports a SLEAP analysis HDF5 file (plus its video) into a new window."""

    @staticmethod
    def do_action(context: "CommandContext", params: dict):
        from sleap.io.format import read

        labels = read(
            params["filename"],
            for_object="labels",
            as_format="analysis",
            video=params["video"],
        )
        window = type(context.app)()
        window.showMaximized()
        window.loadLabelsObject(labels=labels)

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        filename, _ = FileDialog.open(
            context.app,
            dir=None,
            caption="Import SLEAP Analysis HDF5...",
            filter="SLEAP Analysis HDF5 (*.h5 *.hdf5)",
        )
        if len(filename) == 0:
            return False

        # The analysis file doesn't embed the video, so ask the user for it.
        QtWidgets.QMessageBox(text="Please locate the video for this dataset.").exec_()
        chosen_videos = ImportVideos().ask()
        if not chosen_videos:
            return False

        params["filename"] = filename
        params["video"] = ImportVideos.create_video(chosen_videos[0])
        return True
class SaveProjectAs(AppCommand):
    """Saves the project under a filename chosen by the user."""

    @staticmethod
    def _try_save(context, labels: Labels, filename: str):
        """Attempts to save `labels` to `filename`, reporting any error.

        Returns:
            True if writing the file succeeded, otherwise False.
        """
        saved = False
        try:
            Labels.save_file(labels=labels, filename=filename)
            saved = True
            # Record a savepoint so the project is marked as unmodified.
            context.changestack_savepoint()
        except Exception as e:
            message = (
                f"An error occured when attempting to save:\n {e}\n\n"
                "Try saving your project with a different filename or in a different "
                "format."
            )
            QtWidgets.QMessageBox(text=message).exec_()

        # Redraw. Not sure why, but sometimes we need to do this.
        context.app.plotFrame()

        return saved

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        target = params["filename"]
        if not cls._try_save(context, context.state["labels"], target):
            return
        # Only adopt the new filename once the save actually succeeded.
        context.state["filename"] = params["filename"]

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        current = PurePath(context.state["filename"] or "untitled")
        suggestion = str(current.with_name(f"{current.stem} copy{current.suffix}"))

        file_types = ";;".join(
            (
                "SLEAP HDF5 dataset (*.slp)",
                "SLEAP JSON dataset (*.json)",
                "Compressed JSON (*.zip)",
            )
        )
        filename, _ = FileDialog.save(
            context.app,
            caption="Save As...",
            dir=suggestion,
            filter=file_types,
        )
        if len(filename) == 0:
            return False
        params["filename"] = filename
        return True
class ExportAnalysisFile(AppCommand):
    """Writes the project's labels to a SLEAP analysis HDF5 file."""

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor

        output_path = params["output_path"]
        SleapAnalysisAdaptor.write(output_path, context.labels)

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        # Suggest "<project stem>.analysis.h5" alongside the project file.
        project_path = PurePath(context.state["filename"] or "untitled")
        suggestion = str(project_path.with_name(f"{project_path.stem}.analysis.h5"))

        filename, _ = FileDialog.save(
            context.app,
            caption="Export Analysis File...",
            dir=suggestion,
            filter="SLEAP Analysis HDF5 (*.h5)",
        )
        if len(filename) == 0:
            return False
        params["output_path"] = filename
        return True
class SaveProject(SaveProjectAs):
    """Saves to the current filename, or falls back to "Save As" if none."""

    @classmethod
    def ask(cls, context: CommandContext, params: dict) -> bool:
        current = context.state["filename"]
        if current is None:
            # Never saved before (new project), so behave like "Save As".
            return SaveProjectAs.ask(context, params)
        params["filename"] = current
        return True
class ExportLabeledClip(AppCommand):
    """Exports a video clip with instances drawn on the frames."""

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Renders the clip configured by `ask` and optionally opens it.

        Reads the params written by `ask`: "filename", "frames", "fps",
        "color_manager", "show edges", "scale", "crop" and "open_when_done".
        """
        from sleap.io.visuals import save_labeled_video

        save_labeled_video(
            filename=params["filename"],
            labels=context.state["labels"],
            video=context.state["video"],
            frames=list(params["frames"]),
            fps=params["fps"],
            color_manager=params["color_manager"],
            show_edges=params["show edges"],
            scale=params["scale"],
            crop_size_xy=params["crop"],
            gui_progress=True,
        )

        if params["open_when_done"]:
            # Open the file using default video playing app
            from sleap.util import open_file

            open_file(params["filename"])

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        """Shows export options and a save dialog; fills `params`.

        Returns:
            False if there is no active video or the user cancels either
            dialog, otherwise True.
        """
        from sleap.gui.dialogs.export_clip import ExportClipDialog

        video = context.state["video"]
        # The crop and frame-range code below dereferences the video, so there
        # is nothing to export without one.
        if video is None:
            return False

        dialog = ExportClipDialog()

        # Set default fps from video (if video has fps attribute)
        dialog.form_widget.set_form_data(dict(fps=getattr(video, "fps", 30)))

        # Show modal dialog and get form results
        export_options = dialog.get_results()

        # Check if user hit cancel
        if export_options is None:
            return False

        # Use VideoWriter to determine default video type to use
        from sleap.io.videowriter import VideoWriter

        # An unsaved project has no filename yet; use a placeholder rather
        # than failing on `None + str`.
        base_name = context.state["filename"] or "untitled"

        # For OpenCV we default to avi since the bundled ffmpeg makes mp4's
        # that most programs can't open (VLC can). But if we can write mpegs
        # using scikit-video, use .mp4 since it has trouble writing .avi files.
        extension = ".mp4" if VideoWriter.can_use_skvideo() else ".avi"
        default_out_filename = base_name + extension

        # Ask where user wants to save video file
        filename, _ = FileDialog.save(
            context.app,
            caption="Save Video As...",
            dir=default_out_filename,
            filter="Video (*.avi *.mp4)",
        )

        # Check if user hit cancel
        if len(filename) == 0:
            return False

        params["filename"] = filename
        params["fps"] = export_options["fps"]
        params["scale"] = export_options["scale"]
        params["open_when_done"] = export_options["open_when_done"]

        # Determine crop size relative to original size and scale
        # (crop size should be *final* output size, thus already scaled).
        params["crop"] = None
        w = int(video.width * params["scale"])
        h = int(video.height * params["scale"])
        if export_options["crop"] == "Half":
            params["crop"] = (w // 2, h // 2)
        elif export_options["crop"] == "Quarter":
            params["crop"] = (w // 4, h // 4)

        if export_options["use_gui_visuals"]:
            params["color_manager"] = context.app.color_manager
        else:
            params["color_manager"] = None

        params["show edges"] = context.state.get("show edges", default=True)

        # If user selected a clip, use that; otherwise include all frames.
        if context.state["has_frame_range"]:
            params["frames"] = range(*context.state["frame_range"])
        else:
            params["frames"] = range(video.frames)

        return True
class ExportDatasetWithImages(AppCommand):
    """Exports the dataset together with frame images.

    Subclasses choose which frames are included via `all_labeled` and
    `suggested`; both are passed straight to `Labels.save_file`.
    """

    all_labeled = False
    suggested = False

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        win = MessageDialog("Exporting dataset with frame images...", context.app)
        try:
            Labels.save_file(
                context.state["labels"],
                params["filename"],
                default_suffix="slp",
                save_frame_data=True,
                all_labeled=cls.all_labeled,
                suggested=cls.suggested,
            )
        finally:
            # Don't leave the progress message on screen if saving fails.
            win.hide()

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        filters = [
            "SLEAP HDF5 dataset (*.slp *.h5)",
            "Compressed JSON dataset (*.json *.json.zip)",
        ]

        # Suggest "<project name>.pkg.slp" next to the project file. An
        # unsaved project has no filename, so fall back to a placeholder
        # instead of passing None to os.path.
        current = context.state["filename"] or "untitled"
        dirname = os.path.dirname(current)
        basename = os.path.basename(current)
        new_basename = f"{os.path.splitext(basename)[0]}.pkg.slp"
        new_filename = os.path.join(dirname, new_basename)

        # `filter` (singular) matches every other FileDialog call in this
        # module and the underlying Qt keyword.
        filename, _ = FileDialog.save(
            context.app,
            caption="Save Labeled Frames As...",
            dir=new_filename,
            filter=";;".join(filters),
        )
        if len(filename) == 0:
            return False

        params["filename"] = filename
        return True
class ExportUserLabelsPackage(ExportDatasetWithImages):
    """Package export with user-labeled images only."""

    suggested = False
    all_labeled = False
class ExportTrainingPackage(ExportDatasetWithImages):
    """Package export with user-labeled images plus suggested frames."""

    suggested = True
    all_labeled = False
class ExportFullPackage(ExportDatasetWithImages):
    """Package export with all labeled frames plus suggested frames."""

    suggested = True
    all_labeled = True
# Navigation Commands
class GoIteratorCommand(AppCommand):
    """Base for commands that jump to the first frame yielded by an iterator.

    Subclasses implement `_get_frame_iterator`.
    """

    @staticmethod
    def _plot_if_next(context, frame_iterator: Iterator) -> bool:
        """Moves to the next :class:`LabeledFrame` from `frame_iterator`, if any.

        Returns:
            True if we went to a next frame, False if the iterator was empty.
        """
        exhausted = object()
        upcoming = next(frame_iterator, exhausted)
        if upcoming is exhausted:
            return False
        context.state["frame_idx"] = upcoming.frame_idx
        return True

    @staticmethod
    def _get_frame_iterator(context: CommandContext):
        raise NotImplementedError("Call to virtual method.")

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        cls._plot_if_next(context, cls._get_frame_iterator(context))
class GoPreviousLabeledFrame(GoIteratorCommand):
    """Jumps to the closest labeled frame before the current one."""

    @staticmethod
    def _get_frame_iterator(context: CommandContext):
        video = context.state["video"]
        start = context.state["frame_idx"]
        return context.labels.frames(video, from_frame_idx=start, reverse=True)
class GoNextLabeledFrame(GoIteratorCommand):
    """Jumps to the closest labeled frame after the current one."""

    @staticmethod
    def _get_frame_iterator(context: CommandContext):
        video = context.state["video"]
        start = context.state["frame_idx"]
        return context.labels.frames(video, from_frame_idx=start)
class GoNextUserLabeledFrame(GoIteratorCommand):
    """Jumps to the next labeled frame that has user (non-predicted) instances."""

    @staticmethod
    def _get_frame_iterator(context: CommandContext):
        candidates = context.labels.frames(
            context.state["video"], from_frame_idx=context.state["frame_idx"]
        )
        # Lazily keep only frames with user instances.
        return (lf for lf in candidates if lf.has_user_instances)
class NavCommand(AppCommand):
    """Base class for commands which navigate to a frame (optionally in a video)."""

    @staticmethod
    def go_to(
        context: "CommandContext", frame_idx: int, video: Optional["Video"] = None
    ):
        """Moves the GUI to `frame_idx`, first activating `video` if one is given."""
        if video is not None:
            # Switch video before setting the frame so the index refers to the
            # newly active video.
            context.state["video"] = video
        context.state["frame_idx"] = frame_idx


class GoNextSuggestedFrame(NavCommand):
    """Goes to the next suggested frame and selects it in the suggestions list."""

    # Direction passed to `get_next_suggestion`: 1 searches forward
    # (`GoPrevSuggestedFrame` uses -1).
    seek_direction = 1

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        suggestion = context.labels.get_next_suggestion(
            context.state["video"], context.state["frame_idx"], cls.seek_direction
        )
        if suggestion is None:
            return

        cls.go_to(context, suggestion.frame_idx, suggestion.video)
        context.state["suggestion_idx"] = context.labels.get_suggestions().index(
            suggestion
        )
class GoPrevSuggestedFrame(GoNextSuggestedFrame):
    """Same as `GoNextSuggestedFrame`, but searches backward."""

    seek_direction = -1
class GoNextTrackFrame(NavCommand):
    """Jumps to the next frame on which some track starts, and selects it."""

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        video = context.state["video"]
        current_idx = context.state["frame_idx"]
        occupancy = context.labels.get_track_occupancy(video)

        # (start frame, track) for every track beginning after this frame,
        # earliest first.
        upcoming = sorted(
            (
                (span.start, track)
                for track, span in occupancy.items()
                if span.start is not None and span.start > current_idx
            ),
            key=operator.itemgetter(0),
        )
        if not upcoming:
            return

        start_idx, track = upcoming[0]
        cls.go_to(context, start_idx)

        # Select the instance in the new track
        frame = context.labels.find(video, start_idx, return_new=True)[0]
        first_match = next(
            (inst for inst in frame.instances_to_show if inst.track == track), None
        )
        if first_match is not None:
            context.state["instance"] = first_match
class GoFrameGui(NavCommand):
    """Asks for a (1-based) frame number and jumps to it."""

    @classmethod
    def do_action(cls, context: "CommandContext", params: dict):
        cls.go_to(context, params["frame_idx"])

    @classmethod
    def ask(cls, context: "CommandContext", params: dict) -> bool:
        # The dialog shows 1-based frame numbers; state uses 0-based indices.
        current_number = context.state["frame_idx"] + 1
        last_number = context.state["video"].frames
        chosen, accepted = QtWidgets.QInputDialog.getInt(
            context.app,
            "Go To Frame...",
            "Frame Number:",
            current_number,
            1,
            last_number,
        )
        params["frame_idx"] = chosen - 1
        return accepted
class SelectToFrameGui(NavCommand):
    """Asks for a frame number and selects the seekbar range up to it."""

    @classmethod
    def do_action(cls, context: "CommandContext", params: dict):
        start, end = params["from_frame_idx"], params["to_frame_idx"]
        context.app.player.setSeekbarSelection(start, end)

    @classmethod
    def ask(cls, context: "CommandContext", params: dict) -> bool:
        # The dialog shows 1-based frame numbers; state uses 0-based indices.
        current_number = context.state["frame_idx"] + 1
        last_number = context.state["video"].frames
        chosen, accepted = QtWidgets.QInputDialog.getInt(
            context.app,
            "Select To Frame...",
            "Frame Number:",
            current_number,
            1,
            last_number,
        )
        params["from_frame_idx"] = context.state["frame_idx"]
        params["to_frame_idx"] = chosen - 1
        return accepted
# Editing Commands
class EditCommand(AppCommand):
    """Base class for commands that change project data.

    Setting `does_edits` makes `do_with_signal` push onto the change stack.
    """

    does_edits = True
class AddVideo(EditCommand):
    """Import one or more videos into the project."""

    topics = [UpdateTopic.video]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Create videos from ``params["import_list"]`` and add them to the labels."""
        requested = params["import_list"]
        created = ImportVideos.create_videos(requested)

        last_added = None
        for last_added in created:
            context.labels.add_video(last_added)
            context.changestack_push("add video")

        # When nothing was displayed yet, show the last video that was added.
        if context.state["video"] is None:
            context.state["video"] = last_added

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        """Show the video-import dialog; continue only if something was chosen."""
        chosen = ImportVideos().ask()
        params["import_list"] = chosen
        return len(chosen) > 0
class ReplaceVideo(EditCommand):
    """Point the project's videos at new file paths."""

    topics = [UpdateTopic.video]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Apply each path in ``params["new_video_paths"]`` to the matching video."""
        replacements = params["new_video_paths"]
        for vid, path in zip(context.labels.videos, replacements):
            if path == vid.backend.filename:
                continue
            vid.backend.filename = path
            vid.backend.reset()

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        """Let the user choose replacement paths for every video."""
        current_paths = [vid.backend.filename for vid in context.labels.videos]
        # NOTE(review): the dialog presumably edits ``current_paths`` in place,
        # since the same list is stored below as the new paths.
        if not MissingFilesDialog(filenames=current_paths, replace=True).exec_():
            return False
        params["new_video_paths"] = current_paths
        return True
class RemoveVideo(EditCommand):
    """Remove the selected video (and its labels) from the project."""

    topics = [UpdateTopic.video]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Remove ``params["video"]`` and fix the view if it was being shown."""
        doomed = params["video"]
        context.labels.remove_video(doomed)

        if context.state["video"] == doomed:
            remaining = context.labels.videos
            # Fall back to the last remaining video, or nothing at all.
            context.state["video"] = remaining[-1] if len(remaining) > 0 else None

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        """Confirm removal when the video has labeled frames that would be lost."""
        video = context.state["selected_video"]
        if video is None:
            return False

        labeled_count = len(context.labels.find(video))
        if labeled_count > 0:
            answer = QMessageBox.critical(
                context.app,
                "Removing video with labels",
                f"{labeled_count} labeled frames in this video will be deleted, "
                "are you sure you want to remove this video?",
                QMessageBox.Yes,
                QMessageBox.No,
            )
            if answer == QMessageBox.No:
                return False

        params["video"] = video
        return True
class OpenSkeleton(EditCommand):
    """Load a skeleton from a JSON or HDF5 file and make it the active one."""

    topics = [UpdateTopic.skeleton]

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        """Ask for a skeleton file; return False if the dialog was cancelled."""
        file_filter = ";;".join(
            ("JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)")
        )
        chosen_path, _ = FileDialog.open(
            context.app, dir=None, caption="Open skeleton...", filter=file_filter
        )
        if len(chosen_path) == 0:
            return False
        params["filename"] = chosen_path
        return True

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Load the skeleton (format chosen by extension) and register it."""
        path = params["filename"]
        if path.endswith(".json"):
            context.state["skeleton"] = Skeleton.load_json(path)
        elif path.endswith((".h5", ".hdf5")):
            loaded = Skeleton.load_all_hdf5(path)
            # Only the first skeleton in an HDF5 file is used.
            if len(loaded):
                context.state["skeleton"] = loaded[0]

        # Whatever skeleton is now active must be known to the project.
        active = context.state["skeleton"]
        if active not in context.labels:
            context.labels.skeletons.append(active)
class SaveSkeleton(AppCommand):
    """Export the current skeleton to a JSON or HDF5 file."""

    @staticmethod
    def ask(context: CommandContext, params: dict) -> bool:
        """Ask for a destination path; return False if the dialog was cancelled."""
        chosen_path, _ = FileDialog.save(
            context.app,
            caption="Save As...",
            dir="skeleton.json",
            filter=";;".join(
                ("JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)")
            ),
        )
        if len(chosen_path) == 0:
            return False
        params["filename"] = chosen_path
        return True

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Write the skeleton in the format implied by the file extension.

        Paths with any other extension are silently ignored.
        """
        path = params["filename"]
        if path.endswith(".json"):
            context.state["skeleton"].save_json(path)
        elif path.endswith((".h5", ".hdf5")):
            context.state["skeleton"].save_hdf5(path)
class NewNode(EditCommand):
    """Add a node with a fresh, unused name to the current skeleton."""

    topics = [UpdateTopic.skeleton]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Add ``new_part`` (or ``new_part_<n>`` if taken) to the skeleton."""
        skeleton = context.state["skeleton"]

        candidate = "new_part"
        suffix = 1
        while candidate in skeleton:
            candidate = f"new_part_{suffix}"
            suffix += 1

        skeleton.add_node(candidate)
class DeleteNode(EditCommand):
    """Remove the currently selected node from the skeleton."""

    topics = [UpdateTopic.skeleton]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Delete ``state["selected_node"]`` from ``state["skeleton"]``."""
        doomed = context.state["selected_node"]
        skeleton = context.state["skeleton"]
        skeleton.delete_node(doomed)
class SetNodeName(EditCommand):
    """Rename a skeleton node."""

    topics = [UpdateTopic.skeleton]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Relabel ``params["node"]`` in ``params["skeleton"]`` to ``params["name"]``."""
        target_node = params["node"]
        new_name = params["name"]
        params["skeleton"].relabel_node(target_node.name, new_name)
class SetNodeSymmetry(EditCommand):
    """Set or clear the symmetric partner of a skeleton node."""

    topics = [UpdateTopic.skeleton]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Add a symmetry, or remove the node's existing one when cleared.

        Params:
            node: The node being edited.
            symmetry: The partner node; falsy means "clear the symmetry".
            skeleton: The skeleton containing the node.
        """
        node = params["node"]
        symmetry = params["symmetry"]
        skeleton = params["skeleton"]
        if symmetry:
            skeleton.add_symmetry(node, symmetry)
        else:
            # Value was cleared by user, so delete any existing symmetry.
            symmetric_to = skeleton.get_symmetry(node)
            # Assumes get_symmetry returns None when the node has no partner;
            # in that case there is nothing to delete.
            if symmetric_to is not None:
                skeleton.delete_symmetry(node, symmetric_to)
class NewEdge(EditCommand):
    """Connect two nodes of the current skeleton with an edge."""

    topics = [UpdateTopic.skeleton]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Add ``src_node -> dst_node`` if both nodes belong to the skeleton."""
        source = params["src_node"]
        destination = params["dst_node"]
        skeleton = context.state["skeleton"]

        # Silently ignore edges whose endpoints are not in this skeleton.
        if source in skeleton and destination in skeleton:
            skeleton.add_edge(source=source, destination=destination)
class DeleteEdge(EditCommand):
    """Delete the currently selected skeleton edge."""

    topics = [UpdateTopic.skeleton]

    @staticmethod
    def ask(context: "CommandContext", params: dict) -> bool:
        """Capture the selected edge; abort when no edge is selected."""
        edge = context.state["selected_edge"]
        params["edge"] = edge
        # With no selection, do_action would call delete_edge(**None), which
        # raises TypeError, so don't proceed.
        return edge is not None

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Delete ``params["edge"]``, passed as keyword arguments to delete_edge."""
        edge = params["edge"]

        # Delete edge
        context.state["skeleton"].delete_edge(**edge)
class InstanceDeleteCommand(EditCommand):
    """Base class for commands that delete a computed set of instances.

    Subclasses implement `get_frame_instance_list`. `ask` collects the
    (labeled frame, instance) pairs and asks the user to confirm, and
    `do_action` removes them. Frames left with no instances are removed too.
    """

    topics = [UpdateTopic.project_instances]

    @staticmethod
    def get_frame_instance_list(context: CommandContext, params: dict):
        """Return the (labeled frame, instance) pairs to delete; must be overridden."""
        raise NotImplementedError("Call to virtual method.")

    @staticmethod
    def _confirm_deletion(context: CommandContext, lf_inst_list: List) -> bool:
        """Helper function to confirm before deleting instances.

        Args:
            lf_inst_list: A list of (labeled frame, instance) tuples.

        Returns:
            False if the user answered "No", otherwise True.
        """
        title = "Deleting instances"
        message = (
            f"There are {len(lf_inst_list)} instances which "
            f"would be deleted. Are you sure you want to delete these?"
        )

        # Confirm that we want to delete
        resp = QMessageBox.critical(
            context.app, title, message, QMessageBox.Yes, QMessageBox.No
        )

        if resp == QMessageBox.No:
            return False

        return True

    @staticmethod
    def _do_deletion(
        context: CommandContext, lf_inst_list: List[Tuple[LabeledFrame, Instance]]
    ):
        """Remove every (labeled frame, instance) pair, then drop emptied frames."""
        lfs_to_remove = []
        for lf, inst in lf_inst_list:
            # in_transaction=True defers cache updates until the end (see below).
            context.labels.remove_instance(lf, inst, in_transaction=True)
            if len(lf.instances) == 0:
                lfs_to_remove.append(lf)

        context.labels.remove_frames(lfs_to_remove)

        # Update caches since we skipped doing this after each deletion
        context.labels.update_cache()

        # Record the edit on the change stack.
        context.changestack_push("delete instances")

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        """Delete the pairs previously stored in ``params["lf_instance_list"]``."""
        cls._do_deletion(context, params["lf_instance_list"])

    @classmethod
    def ask(cls, context: CommandContext, params: dict) -> bool:
        """Compute the instances to delete, store them, and confirm with the user."""
        lf_inst_list = cls.get_frame_instance_list(context, params)
        params["lf_instance_list"] = lf_inst_list

        return cls._confirm_deletion(context, lf_inst_list)
class DeleteAllPredictions(InstanceDeleteCommand):
    """Delete every predicted instance in the whole project."""

    @staticmethod
    def get_frame_instance_list(
        context: CommandContext, params: dict
    ) -> List[Tuple[LabeledFrame, Instance]]:
        """Collect (frame, instance) pairs for all exact `PredictedInstance`s."""
        pairs = []
        for frame in context.labels:
            for candidate in frame:
                # Exact type match on purpose (not isinstance).
                if type(candidate) is PredictedInstance:
                    pairs.append((frame, candidate))
        return pairs
class DeleteFramePredictions(InstanceDeleteCommand):
    """Delete predicted instances on the current frame, without confirmation."""

    @staticmethod
    def _confirm_deletion(self, *args, **kwargs):
        # Deleting from just the current frame never asks the user.
        return True

    @staticmethod
    def get_frame_instance_list(context: CommandContext, params: dict):
        """Collect (frame, instance) pairs for predictions on the current frame."""
        frames_here = context.labels.find(
            context.state["video"], frame_idx=context.state["frame_idx"]
        )
        pairs = []
        for frame in frames_here:
            pairs.extend(
                (frame, candidate)
                for candidate in frame
                if type(candidate) is PredictedInstance
            )
        return pairs
class DeleteClipPredictions(InstanceDeleteCommand):
    """Delete predicted instances within the selected frame range."""

    @staticmethod
    def get_frame_instance_list(context: CommandContext, params: dict):
        """Collect (frame, instance) pairs for predictions in ``state["frame_range"]``."""
        clip_frames = range(*context.state["frame_range"])
        frames_in_clip = context.labels.find(
            context.state["video"], frame_idx=clip_frames
        )
        pairs = []
        for frame in frames_in_clip:
            pairs.extend(
                (frame, candidate)
                for candidate in frame
                if type(candidate) is PredictedInstance
            )
        return pairs
class DeleteAreaPredictions(InstanceDeleteCommand):
    """Delete predicted instances lying entirely inside a user-drawn area.

    The area is applied to every labeled frame of the current video.
    """

    @staticmethod
    def get_frame_instance_list(context: CommandContext, params: dict):
        """Return (frame, instance) pairs for predictions inside the area.

        Params:
            min_corner: (x, y) of the area's top-left corner.
            max_corner: (x, y) of the area's bottom-right corner.
        """
        min_corner = params["min_corner"]
        max_corner = params["max_corner"]

        def is_bounded(inst):
            # Only rows without NaN are tested; missing points are ignored.
            points_array = inst.points_array
            valid_points = points_array[~np.isnan(points_array).any(axis=1)]
            is_gt_min = np.all(valid_points >= min_corner)
            is_lt_max = np.all(valid_points <= max_corner)
            return is_gt_min and is_lt_max

        # Find all instances contained in selected area
        predicted_instances = [
            (lf, inst)
            for lf in context.labels.find(context.state["video"])
            for inst in lf
            if type(inst) == PredictedInstance and is_bounded(inst)
        ]

        return predicted_instances

    @classmethod
    def ask_and_do(cls, context: CommandContext, params: dict):
        """Let the user draw an area, then confirm and delete predictions in it."""

        # Callback to delete after area has been selected
        def delete_area_callback(x0, y0, x1, y1):
            context.app.updateStatusMessage()

            # Make sure there was an area selected
            if x0 == x1 or y0 == y1:
                return

            # Normalize the corners so the bounds test works no matter which
            # direction the area was dragged.
            params["min_corner"] = (min(x0, x1), min(y0, y1))
            params["max_corner"] = (max(x0, x1), max(y0, y1))

            predicted_instances = cls.get_frame_instance_list(context, params)

            if cls._confirm_deletion(context, predicted_instances):
                params["lf_instance_list"] = predicted_instances
                cls.do_with_signal(context, params)

        # Prompt the user to select area
        context.app.updateStatusMessage(
            "Please select the area from which to remove instances. "
            "This will be applied to all frames."
        )
        context.app.player.onAreaSelection(delete_area_callback)
class DeleteLowScorePredictions(InstanceDeleteCommand):
    """Delete predicted instances in the current video scoring below a threshold."""

    @staticmethod
    def get_frame_instance_list(context: CommandContext, params: dict):
        """Return (frame, instance) pairs whose score is below ``score_threshold``."""
        score_thresh = params["score_threshold"]
        predicted_instances = [
            (lf, inst)
            for lf in context.labels.find(context.state["video"])
            for inst in lf
            if type(inst) == PredictedInstance and inst.score < score_thresh
        ]
        return predicted_instances

    @classmethod
    def ask(cls, context: CommandContext, params: dict) -> bool:
        """Ask for the score threshold, then defer to the base confirmation."""
        score_thresh, okay = QtWidgets.QInputDialog.getDouble(
            context.app, "Delete Instances with Low Score...", "Score Below:", 1, 0, 100
        )
        if not okay:
            # Dialog cancelled: return an explicit bool instead of None.
            return False
        params["score_threshold"] = score_thresh
        return super().ask(context, params)
class DeleteFrameLimitPredictions(InstanceDeleteCommand):
    """Keep at most N predicted instances per frame, deleting the lowest-scoring."""

    @staticmethod
    def get_frame_instance_list(context: CommandContext, params: dict):
        """Return the (frame, instance) pairs exceeding ``count_threshold`` per frame."""
        count_thresh = params["count_threshold"]
        predicted_instances = []

        # Check every labeled frame of the current video.
        for lf in context.labels.find(context.state["video"]):
            frame_predictions = lf.predicted_instances
            n_extra = len(frame_predictions) - count_thresh
            if n_extra > 0:
                # Ascending by score, so the lowest-scoring come first. Slicing by
                # the excess count (rather than [:-count_thresh]) also works for
                # a threshold of 0, where [:-0] would select nothing.
                lowest_first = sorted(
                    frame_predictions, key=operator.attrgetter("score")
                )
                predicted_instances.extend(
                    (lf, inst) for inst in lowest_first[:n_extra]
                )

        return predicted_instances

    @classmethod
    def ask(cls, context: CommandContext, params: dict) -> bool:
        """Ask for the per-frame limit, then defer to the base confirmation."""
        count_thresh, okay = QtWidgets.QInputDialog.getInt(
            context.app,
            "Limit Instances in Frame...",
            "Maximum instances in a frame:",
            3,
            1,
            100,
        )
        if not okay:
            # Dialog cancelled: return an explicit bool instead of None.
            return False
        params["count_threshold"] = count_thresh
        return super().ask(context, params)
class TransposeInstances(EditCommand):
    """Swap the tracks of two instances from the current frame onward."""

    topics = [UpdateTopic.project_instances, UpdateTopic.tracks]

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        """Swap the tracks of ``params["instances"]`` (exactly two) to video end."""
        instances = params["instances"]

        if len(instances) != 2:
            return

        # Swap tracks for current and subsequent frames when we have tracks
        old_track, new_track = instances[0].track, instances[1].track
        if old_track is not None and new_track is not None:
            frame_range = (context.state["frame_idx"], context.state["video"].frames)
            context.labels.track_swap(
                context.state["video"], new_track, old_track, frame_range
            )

    @classmethod
    def ask_and_do(cls, context: CommandContext, params: dict):
        """Transpose directly when there are two instances, else ask the user to pick two."""

        def on_each(instances: list):
            word = "next" if len(instances) else "first"
            context.app.updateStatusMessage(
                f"Please select the {word} instance to transpose..."
            )

        def on_success(instances: list):
            params["instances"] = instances
            cls.do_with_signal(context, params)

        labeled_frame = context.state["labeled_frame"]

        # Nothing to transpose without a frame holding at least two instances.
        if labeled_frame is None or len(labeled_frame.instances) < 2:
            return

        # If there are just two instances, transpose them.
        if len(labeled_frame.instances) == 2:
            params["instances"] = labeled_frame.instances
            cls.do_with_signal(context, params)

        # If there are more than two, then we need the user to select the instances.
        else:
            context.app.player.onSequenceSelect(
                seq_len=2,
                on_success=on_success,
                on_each=on_each,
                on_failure=lambda x: context.app.updateStatusMessage(),
            )
class DeleteSelectedInstance(EditCommand):
    """Delete the instance currently selected in the GUI."""

    topics = [UpdateTopic.frame, UpdateTopic.project_instances, UpdateTopic.suggestions]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Remove ``state["instance"]`` from the current labeled frame, if any."""
        target = context.state["instance"]
        if target is not None:
            context.labels.remove_instance(context.state["labeled_frame"], target)
class DeleteSelectedInstanceTrack(EditCommand):
    """Delete the selected instance and every instance on its track."""

    topics = [
        UpdateTopic.project_instances,
        UpdateTopic.tracks,
        UpdateTopic.suggestions,
    ]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Remove the selected instance, then all same-track instances in the video."""
        selected_inst = context.state["instance"]
        if selected_inst is None:
            return

        track = selected_inst.track
        context.labels.remove_instance(context.state["labeled_frame"], selected_inst)

        if track is not None:
            # remove any instance on this track
            for lf in context.labels.find(context.state["video"]):
                # Materialize the matches before removing anything. If
                # remove_instance shrinks lf.instances in place (assumed), a lazy
                # filter over that same list would skip elements.
                track_instances = [
                    inst for inst in lf.instances if inst.track == track
                ]
                for inst in track_instances:
                    context.labels.remove_instance(lf, inst)
class DeleteDialogCommand(EditCommand):
    """Open the delete dialog, which both asks and performs the deletion."""

    topics = [
        UpdateTopic.project_instances,
    ]

    @staticmethod
    def ask_and_do(context: CommandContext, params: dict):
        """Run `DeleteDialog`; refresh instance views only if it was accepted."""
        accepted = DeleteDialog(context).exec_()
        if not accepted:
            return
        context.signal_update([UpdateTopic.project_instances])
class AddTrack(EditCommand):
    """Create a new numbered track and assign the selected instance to it."""

    topics = [UpdateTopic.tracks]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Add a track named one more than the largest numeric track name."""
        # isdecimal (not isnumeric) matches exactly what int() accepts; e.g.
        # "²" or "½" are numeric but would make int() raise ValueError.
        track_numbers_used = [
            int(track.name) for track in context.labels.tracks if track.name.isdecimal()
        ]
        next_number = max(track_numbers_used, default=0) + 1
        new_track = Track(spawned_on=context.state["frame_idx"], name=str(next_number))

        context.labels.add_track(context.state["video"], new_track)

        context.execute(SetSelectedInstanceTrack, new_track=new_track)
class SetSelectedInstanceTrack(EditCommand):
    """Move the selected instance onto a given track."""

    topics = [UpdateTopic.tracks]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Assign ``params["new_track"]`` to the selected instance.

        If the instance has no track yet, only the current frame changes:
        anything already on the target track here is detached first. If it
        already has a track, the two tracks are swapped over the selected
        seekbar range, or from the current frame to the end of the video when
        no range is selected.
        """
        chosen = context.state["instance"]
        target_track = params["new_track"]
        if chosen is None:
            return

        current_track = chosen.track

        if current_track is not None:
            # Swap the existing and target tracks across a range of frames.
            if context.state["has_frame_range"]:
                swap_range = tuple(context.state["frame_range"])
            else:
                swap_range = (
                    context.state["frame_idx"],
                    context.state["video"].frames,
                )
            context.labels.track_swap(
                context.state["video"], target_track, current_track, swap_range
            )
        else:
            # Detach whatever occupies the target track on this frame only.
            here = context.state["frame_idx"]
            occupants = context.labels.find_track_occupancy(
                video=context.state["video"],
                track=target_track,
                frame_range=(here, here + 1),
            )
            for occupant in occupants:
                occupant.track = None

            context.labels.track_set_instance(
                context.state["labeled_frame"], chosen, target_track
            )

        # Re-select the original instance after the change.
        context.state["instance"] = chosen
class SetTrackName(EditCommand):
    """Rename a track."""

    topics = [UpdateTopic.tracks, UpdateTopic.frame]

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Set ``params["track"].name`` to ``params["name"]``."""
        target_track = params["track"]
        new_name = params["name"]
        setattr(target_track, "name", new_name)
class GenerateSuggestions(EditCommand):
    """Compute a new list of suggested frames for labeling."""

    topics = [UpdateTopic.suggestions]

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        """Run `VideoFrameSuggestions.suggest` with ``params`` and store the result."""
        win = MessageDialog("Generating list of suggested frames...", context.app)

        try:
            new_suggestions = VideoFrameSuggestions.suggest(
                labels=context.labels, params=params
            )
            context.labels.set_suggestions(new_suggestions)
        finally:
            # Always dismiss the progress message, even if generation fails.
            win.hide()
class MergeProject(EditCommand):
    """Merge labeled data from one or more project files into this project."""

    topics = [UpdateTopic.all]

    @classmethod
    def ask_and_do(cls, context: CommandContext, params: dict):
        """Pick files, run `MergeDialog` for each, then signal a full refresh."""
        file_filter = ";;".join(
            [
                "SLEAP HDF5 dataset (*.slp *.h5 *.hdf5)",
                "SLEAP JSON dataset (*.json *.json.zip)",
            ]
        )
        chosen, _ = FileDialog.openMultiple(
            context.app,
            dir=None,
            caption="Import labeled data...",
            filter=file_filter,
        )
        if len(chosen) == 0:
            return

        for path in chosen:
            # Look for missing videos next to the file being imported.
            video_callback = Labels.make_gui_video_callback(
                search_paths=[os.path.dirname(path)]
            )
            incoming = Labels.load_file(path, video_search=video_callback)

            # MergeDialog performs the actual merge.
            MergeDialog(base_labels=context.labels, new_labels=incoming).exec_()

        cls.do_with_signal(context, params)
class AddInstance(EditCommand):
    """Add a new user instance to the current labeled frame.

    Points are copied from an existing instance when one is available (the
    selected instance, an unused prediction, or an instance from the previous
    labeled frame). Nodes still missing afterwards are filled in according to
    ``params["init_method"]``.
    """

    topics = [UpdateTopic.frame, UpdateTopic.project_instances, UpdateTopic.suggestions]

    @staticmethod
    def get_previous_frame_index(context: CommandContext) -> Optional[int]:
        """Return the frame index of the previous labeled frame, or None."""
        frames = context.labels.frames(
            context.state["video"],
            from_frame_idx=context.state["frame_idx"],
            reverse=True,
        )

        try:
            next_idx = next(frames).frame_idx
        except Exception:
            # Best-effort lookup: any failure (e.g. StopIteration when there is
            # no earlier frame) means there is nothing to copy from. Unlike a
            # bare `except:`, this does not swallow KeyboardInterrupt/SystemExit.
            return None
        return next_idx

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        """Create the instance.

        Params (all optional):
            copy_instance: Instance to copy points from.
            init_method: "best", "prediction", "prior_frame", "force_directed",
                "random" or "template" (default "best").
            location: Center point used by force-directed/template placement.
            mark_complete: Value for the `complete` flag of copied points.
        """
        copy_instance = params.get("copy_instance", None)
        init_method = params.get("init_method", "best")
        location = params.get("location", None)
        mark_complete = params.get("mark_complete", False)

        if context.state["labeled_frame"] is None:
            return

        # FIXME: filter by skeleton type

        from_predicted = copy_instance
        from_prev_frame = False

        if init_method == "best" and copy_instance is None:
            selected_inst = context.state["instance"]
            if selected_inst is not None:
                # If the user has selected an instance, copy that one.
                copy_instance = selected_inst
                from_predicted = copy_instance

        if (
            init_method == "best" and copy_instance is None
        ) or init_method == "prediction":
            unused_predictions = context.state["labeled_frame"].unused_predictions
            if len(unused_predictions):
                # If there are predicted instances that don't correspond to an instance
                # in this frame, use the first predicted instance without matching instance.
                copy_instance = unused_predictions[0]
                from_predicted = copy_instance

        if (
            init_method == "best" and copy_instance is None
        ) or init_method == "prior_frame":
            # Otherwise, if there are instances in previous frames,
            # copy the points from one of those instances.
            prev_idx = cls.get_previous_frame_index(context)

            if prev_idx is not None:
                prev_instances = context.labels.find(
                    context.state["video"], prev_idx, return_new=True
                )[0].instances

                if len(prev_instances) > len(context.state["labeled_frame"].instances):
                    # If more instances in previous frame than current, then use the
                    # first unmatched instance.
                    copy_instance = prev_instances[
                        len(context.state["labeled_frame"].instances)
                    ]
                    from_prev_frame = True
                elif init_method == "best" and (
                    context.state["labeled_frame"].instances
                ):
                    # Otherwise, if there are already instances in current frame,
                    # copy the points from the last instance added to frame.
                    copy_instance = context.state["labeled_frame"].instances[-1]
                elif len(prev_instances):
                    # Otherwise use the last instance added to previous frame.
                    copy_instance = prev_instances[-1]
                    from_prev_frame = True

        # Only predicted instances (which have a score) count as "from_predicted".
        from_predicted = from_predicted if hasattr(from_predicted, "score") else None

        # Now create the new instance
        new_instance = Instance(
            skeleton=context.state["skeleton"],
            from_predicted=from_predicted,
            frame=context.state["labeled_frame"],
        )

        has_missing_nodes = False

        # go through each node in skeleton
        for node in context.state["skeleton"].node_names:
            # if we're copying from a skeleton that has this node
            if (
                copy_instance is not None
                and node in copy_instance
                and not copy_instance[node].isnan()
            ):
                # just copy x, y, and visible
                # we don't want to copy a PredictedPoint or score attribute
                new_instance[node] = Point(
                    x=copy_instance[node].x,
                    y=copy_instance[node].y,
                    visible=copy_instance[node].visible,
                    complete=mark_complete,
                )
            else:
                has_missing_nodes = True

        if has_missing_nodes:
            # Mark the added nodes as not "visible" if we're copying from a
            # predicted instance that lacks them.
            is_visible = copy_instance is None or (not hasattr(copy_instance, "score"))
            if init_method == "force_directed":
                AddMissingInstanceNodes.add_force_directed_nodes(
                    context=context,
                    instance=new_instance,
                    visible=is_visible,
                    center_point=location,
                )
            elif init_method == "random":
                AddMissingInstanceNodes.add_random_nodes(
                    context=context, instance=new_instance, visible=is_visible
                )
            elif init_method == "template":
                AddMissingInstanceNodes.add_nodes_from_template(
                    context=context,
                    instance=new_instance,
                    visible=is_visible,
                    center_point=location,
                )
            else:
                AddMissingInstanceNodes.add_best_nodes(
                    context=context, instance=new_instance, visible=is_visible
                )

        # If we're copying a predicted instance or from another frame, copy the track
        if hasattr(copy_instance, "score") or from_prev_frame:
            new_instance.track = copy_instance.track

        # Add the instance
        context.labels.add_instance(context.state["labeled_frame"], new_instance)

        if context.state["labeled_frame"] not in context.labels.labels:
            context.labels.append(context.state["labeled_frame"])
class SetInstancePointLocations(EditCommand):
    """Move one or more nodes of an instance to new coordinates.

    This command deliberately does *not* refresh the visual scene: a refresh
    would redraw the frame and create new visual objects. The caller is
    responsible for updating the scene.

    Params:
        instance: The instance whose points are moved.
        nodes_locations: Mapping from node (or node name) to an (x, y) tuple.
            Nodes that are not in the instance are skipped.
    """

    topics = []

    @classmethod
    def do_action(cls, context: "CommandContext", params: dict):
        """Write each (x, y) into the matching point of ``params["instance"]``."""
        instance = params["instance"]
        for node, (new_x, new_y) in params["nodes_locations"].items():
            if node not in instance:
                continue
            point = instance[node]
            point.x = new_x
            point.y = new_y
class SetInstancePointVisibility(EditCommand):
    """Set or clear the visibility flag of one node of an instance.

    This command deliberately does *not* refresh the visual scene: a refresh
    would redraw the frame and create new visual objects. The caller is
    responsible for updating the scene.

    Params:
        instance: The instance to modify.
        node: The `Node` (or its name).
        visible: The new visibility value.
    """

    topics = []

    @classmethod
    def do_action(cls, context: "CommandContext", params: dict):
        """Assign ``params["visible"]`` to the node's point."""
        point = params["instance"][params["node"]]
        point.visible = params["visible"]
class AddMissingInstanceNodes(EditCommand):
    """Fill in an instance's missing nodes.

    Also provides the placement helpers (template, random, force-directed)
    used by `AddInstance`.
    """

    topics = [UpdateTopic.frame]

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        """Fill missing nodes of ``params["instance"]`` using `add_best_nodes`."""
        instance = params["instance"]
        visible = params.get("visible", False)

        cls.add_best_nodes(context, instance, visible)

    @classmethod
    def add_best_nodes(cls, context, instance, visible):
        """Place missing nodes from the template, then randomly for any left over."""
        # Try placing missing nodes using a "template" instance
        cls.add_nodes_from_template(context, instance, visible)

        # If the "template" instance has missing nodes (i.e., a node that isn't
        # labeled on any of the instances we used to generate the template),
        # then adding nodes from the template may still result in missing nodes.
        # So we'll use random placement for anything that's still missing.
        cls.add_random_nodes(context, instance, visible)

    @classmethod
    def add_random_nodes(cls, context, instance, visible):
        """Give every absent or NaN node a random point inside the visible view."""
        # the rect that's currently visible in the window view
        in_view_rect = context.app.player.getVisibleRect()

        for node in context.state["skeleton"].nodes:
            if node not in instance.nodes or instance[node].isnan():
                # pick random points within currently zoomed view
                x, y = cls.get_xy_in_rect(in_view_rect)
                # set point for node
                instance[node] = Point(x=x, y=y, visible=visible)

    @staticmethod
    def get_xy_in_rect(rect: QtCore.QRectF):
        """Returns random x, y coordinates within the central 80% of the rect."""
        x = rect.x() + (rect.width() * 0.1) + (np.random.rand() * rect.width() * 0.8)
        y = rect.y() + (rect.height() * 0.1) + (np.random.rand() * rect.height() * 0.8)
        return x, y

    @staticmethod
    def get_rect_center_xy(rect: QtCore.QRectF):
        """Returns x, y at center of rect."""
        # Previously a docstring-only stub that returned None.
        center = rect.center()
        return center.x(), center.y()

    @classmethod
    def add_nodes_from_template(
        cls,
        context,
        instance,
        visible: bool = False,
        center_point: QtCore.QPoint = None,
    ):
        """Place nodes missing from ``instance`` using the project's template points.

        If the instance already has points, the template is aligned onto them.
        Otherwise the template is centered on ``center_point``, or on the
        center of the visible view when ``center_point`` is None.
        """
        from sleap.info import align

        # Get the "template" instance
        template_points = context.labels.get_template_instance_points(
            skeleton=instance.skeleton
        )

        # Align the template on to the current instance with missing points
        if instance.points:
            aligned_template = align.align_instance_points(
                source_points_array=template_points,
                target_points_array=instance.points_array,
            )
        else:
            template_mean = np.nanmean(template_points, axis=0)

            # Explicit None check: a falsy (e.g. null) point is still a real location.
            if center_point is None:
                center_point = context.app.player.getVisibleRect().center()
            center = np.array([center_point.x(), center_point.y()])

            aligned_template = template_points + (center - template_mean)

        # Make missing points from the aligned template
        for i, node in enumerate(instance.skeleton.nodes):
            if node not in instance:
                x, y = aligned_template[i]
                instance[node] = Point(x=x, y=y, visible=visible)

    @classmethod
    def add_force_directed_nodes(
        cls, context, instance, visible, center_point: QtCore.QPoint = None
    ):
        """Set every node from a networkx spring layout of the skeleton graph."""
        import networkx as nx

        # Explicit None check: a falsy (e.g. null) point is still a real location.
        if center_point is None:
            center_point = context.app.player.getVisibleRect().center()
        center_tuple = (center_point.x(), center_point.y())

        node_positions = nx.spring_layout(
            G=context.state["skeleton"].graph, center=center_tuple, scale=50
        )

        for node, pos in node_positions.items():
            instance[node] = Point(x=pos[0], y=pos[1], visible=visible)
class AddUserInstancesFromPredictions(EditCommand):
    """Turn every unused prediction on the current frame into a user instance."""

    topics = [UpdateTopic.frame, UpdateTopic.project_instances]

    @staticmethod
    def make_instance_from_predicted_instance(
        copy_instance: PredictedInstance,
    ) -> Instance:
        """Build a user `Instance` with the prediction's points and track.

        Only x, y and visibility are copied (not scores), points are marked
        incomplete, and NaN or absent nodes are left unset.
        """
        result = Instance(
            skeleton=copy_instance.skeleton,
            from_predicted=copy_instance,
            frame=copy_instance.frame,
        )

        for node_name in result.skeleton.node_names:
            if node_name not in copy_instance:
                continue
            source_point = copy_instance[node_name]
            if source_point.isnan():
                continue
            result[node_name] = Point(
                x=source_point.x,
                y=source_point.y,
                visible=source_point.visible,
                complete=False,
            )

        result.track = copy_instance.track
        return result

    @classmethod
    def do_action(cls, context: CommandContext, params: dict):
        """Convert and add all unused predictions of the current labeled frame."""
        if context.state["labeled_frame"] is None:
            return

        converted = [
            cls.make_instance_from_predicted_instance(prediction)
            for prediction in context.state["labeled_frame"].unused_predictions
        ]

        for user_instance in converted:
            context.labels.add_instance(context.state["labeled_frame"], user_instance)
class OpenWebsite(AppCommand):
    """Open ``params["url"]`` in the system's default browser."""

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Hand the URL to Qt's desktop services."""
        target = QtCore.QUrl(params["url"])
        QtGui.QDesktopServices.openUrl(target)
class CheckForUpdates(AppCommand):
    """Query for new releases and update the version menu entries."""

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Run the release check and enable the menu items for found releases."""
        checker = context.app.release_checker
        success = checker.check_for_releases()

        if success:
            # Either release may be missing (OpenStableVersion and
            # OpenPrereleaseVersion guard against None too), so only update the
            # menu entries we have a release for.
            stable = checker.latest_stable
            if stable is not None:
                context.state["stable_version_menu"].setText(
                    f" Stable: {stable.version}"
                )
                context.state["stable_version_menu"].setEnabled(True)

            prerelease = checker.latest_prerelease
            if prerelease is not None:
                context.state["prerelease_version_menu"].setText(
                    f" Prerelease: {prerelease.version}"
                )
                context.state["prerelease_version_menu"].setEnabled(True)

        # TODO: Provide GUI feedback about result.
class OpenStableVersion(AppCommand):
    """Open the web page of the latest stable release, if one is known."""

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Open ``latest_stable.url``; do nothing when no release is known."""
        release = context.app.release_checker.latest_stable
        if release is None:
            return
        context.openWebsite(release.url)
class OpenPrereleaseVersion(AppCommand):
    """Open the web page of the latest pre-release, if one is known."""

    @staticmethod
    def do_action(context: CommandContext, params: dict):
        """Open ``latest_prerelease.url``; do nothing when no release is known."""
        release = context.app.release_checker.latest_prerelease
        if release is None:
            return
        context.openWebsite(release.url)