"""Training-related tf.keras callbacks."""
import jsonpickle
import logging
import numpy as np
import tensorflow as tf
import zmq
import io
import os
import matplotlib
import matplotlib.pyplot as plt
from typing import Text, Callable, Optional
logger = logging.getLogger(__name__)
class TrainingControllerZMQ(tf.keras.callbacks.Callback):
    """Keras callback that lets an external ZMQ publisher control training.

    The callback subscribes to `topic` on `address` with a ZMQ SUB socket. At the
    end of every training batch it polls the socket for at most `poll_timeout`
    (pyzmq `Socket.poll` timeout, in milliseconds). If a message is available, it
    receives one jsonpickle-encoded message and acts on it.

    Supported messages:
        {"command": "stop"}: sets `model.stop_training = True`.
        {"command": "set_lr", "lr": value}: sets the optimizer learning rate.
    Any other command is ignored. Malformed messages are logged and skipped.

    Attributes:
        address: ZMQ endpoint the SUB socket connects to.
        topic: Subscription topic prefix.
        timeout: Poll timeout passed to `Socket.poll` on each batch end.
        context: The ZMQ context owned by this callback.
        socket: The ZMQ SUB socket owned by this callback.
    """

    def __init__(self, address="tcp://127.0.0.1:9000", topic="", poll_timeout=10):
        # Callback initialization first, so Keras bookkeeping exists before ZMQ setup.
        super().__init__()
        self.address = address
        self.topic = topic
        self.timeout = poll_timeout

        # Initialize ZMQ.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.SUB)
        self.socket.subscribe(self.topic)
        # NOTE: ZMQ connects asynchronously, so an unreachable address does not
        # raise here; messages simply never arrive.
        self.socket.connect(self.address)
        logger.info(
            f"Training controller subscribed to: {self.address} (topic: {self.topic})"
        )

    def __del__(self):
        """Close the socket and terminate the context, tolerating partial init."""
        logger.info("Closing the training controller socket/context.")
        # __init__ may have failed before these attributes were assigned.
        socket = getattr(self, "socket", None)
        if socket is not None:
            # Drop any pending messages so context.term() cannot block forever.
            socket.setsockopt(zmq.LINGER, 0)
            socket.close()
        context = getattr(self, "context", None)
        if context is not None:
            context.term()

    def on_batch_end(self, batch, logs=None):
        """Poll for and apply at most one control message after a training batch.

        Args:
            batch: Index of the batch that just ended (unused).
            logs: Batch metrics from Keras (unused).
        """
        if not self.socket.poll(self.timeout, zmq.POLLIN):
            return

        # NOTE(review): jsonpickle.decode can reconstruct arbitrary Python objects;
        # only connect this controller to a trusted publisher.
        msg = jsonpickle.decode(self.socket.recv_string())
        logger.info(f"Received control message: {msg}")

        if not isinstance(msg, dict):
            logger.warning(f"Ignoring control message that is not a dict: {msg!r}")
            return

        command = msg.get("command")
        if command == "stop":
            # Keras sets self.model when the callback is attached to training.
            self.model.stop_training = True
        elif command == "set_lr":
            if "lr" not in msg:
                logger.warning(f"Ignoring 'set_lr' control message without 'lr': {msg!r}")
                return
            self.set_lr(msg["lr"])

    def set_lr(self, lr):
        """Adjust the model learning rate.

        This mirrors the approach used by the native Keras learning-rate
        scheduling callbacks.

        Args:
            lr: New learning rate. Values that are not Python/NumPy floats are
                converted with `np.array(lr).astype(np.float64)`.
        """
        if not isinstance(lr, (float, np.float32, np.float64)):
            lr = np.array(lr).astype(np.float64)
        optimizer = self.model.optimizer
        # `lr` is a deprecated alias of `learning_rate` on tf.keras optimizers;
        # prefer the canonical name and fall back for older optimizer classes.
        lr_variable = getattr(optimizer, "learning_rate", None)
        if lr_variable is None:
            lr_variable = optimizer.lr
        tf.keras.backend.set_value(lr_variable, lr)
class ProgressReporterZMQ(tf.keras.callbacks.Callback):
    """Keras callback that publishes training progress events over ZMQ.

    A ZMQ PUB socket is connected to `address`. Every hook below publishes one
    jsonpickle-encoded dict as a string. The dict has the keys, in order:
    "what" (this reporter's label), "event" (the hook name, e.g. "batch_end"),
    "batch" or "epoch" where the hook provides one, and "logs" (the Keras logs
    dict passed to the hook, possibly None).

    Attributes:
        address: ZMQ endpoint the PUB socket connects to.
        what: Label attached to every published message.
        context: The ZMQ context owned by this callback.
        socket: The ZMQ PUB socket owned by this callback.
    """

    def __init__(self, address="tcp://127.0.0.1:9001", what="not_set"):
        # Callback initialization first, so Keras bookkeeping exists before ZMQ setup.
        super().__init__()
        self.address = address
        self.what = what

        # Initialize ZMQ.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        # NOTE: ZMQ connects asynchronously, so an unreachable address does not
        # raise here.
        self.socket.connect(self.address)
        logger.info(f"Progress reporter publishing on: {self.address} for: {self.what}")

    def __del__(self):
        """Close the socket and terminate the context, tolerating partial init."""
        logger.info("Closing the reporter controller/context.")
        # __init__ may have failed before these attributes were assigned.
        socket = getattr(self, "socket", None)
        if socket is not None:
            # Discard unsent messages so context.term() cannot block forever.
            socket.setsockopt(zmq.LINGER, 0)
            socket.close()
        context = getattr(self, "context", None)
        if context is not None:
            context.term()

    def _send(self, event, **fields):
        """Publish one progress message for `event` with extra `fields` appended."""
        # Key order (what, event, *fields) matches the message layout documented above.
        self.socket.send_string(
            jsonpickle.encode(dict(what=self.what, event=event, **fields))
        )

    def on_train_begin(self, logs=None):
        """Publish a "train_begin" event.

        Args:
            logs: Keras logs dict for this hook (may be None).
        """
        self._send("train_begin", logs=logs)

    def on_batch_begin(self, batch, logs=None):
        """Publish a "batch_begin" event for training batch `batch`.

        Args:
            batch: Index of the batch within the current epoch.
            logs: Keras logs dict for this hook (may be None).
        """
        self._send("batch_begin", batch=batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        """Publish a "batch_end" event for training batch `batch`.

        Args:
            batch: Index of the batch within the current epoch.
            logs: Keras batch metrics for this hook (may be None).
        """
        self._send("batch_end", batch=batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        """Publish an "epoch_begin" event.

        Args:
            epoch: Index of the epoch that is starting.
            logs: Keras logs dict for this hook (may be None).
        """
        self._send("epoch_begin", epoch=epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Publish an "epoch_end" event.

        Args:
            epoch: Index of the epoch that just ended.
            logs: Keras epoch metrics, including `val_`-prefixed validation
                metrics when validation is performed.
        """
        self._send("epoch_end", epoch=epoch, logs=logs)

    def on_train_end(self, logs=None):
        """Publish a "train_end" event.

        Args:
            logs: Keras logs dict for this hook (may be None).
        """
        self._send("train_end", logs=logs)
class ModelCheckpointOnEvent(tf.keras.callbacks.Callback):
    """Callback for model checkpointing on a fixed event.

    Attributes:
        filepath: Path to save the model to. When `event` is "epoch_end" and the
            path contains a "%" character, it is %-formatted with the epoch index
            (e.g. "model.%03d.h5"), so each epoch writes its own file.
        event: Event that triggers model saving. Valid values are "train_begin",
            "epoch_end" and "train_end". "train_start" is accepted as an alias of
            "train_begin".
    """

    # Event names that trigger a save at the start of training.
    _TRAIN_BEGIN_EVENTS = ("train_begin", "train_start")
    # All event names this callback reacts to.
    _KNOWN_EVENTS = _TRAIN_BEGIN_EVENTS + ("epoch_end", "train_end")

    def __init__(self, filepath: str, event: str = "train_end"):
        super().__init__()
        self.filepath = filepath
        self.event = event
        if event not in self._KNOWN_EVENTS:
            # Unknown events would otherwise silently never save anything.
            logger.warning(
                f"ModelCheckpointOnEvent: unrecognized event {event!r}; "
                f"expected one of {self._KNOWN_EVENTS}. No checkpoints will be saved."
            )

    def on_train_begin(self, logs=None):
        """Save the model at the start of training if configured to."""
        if self.event in self._TRAIN_BEGIN_EVENTS:
            self.model.save(self.filepath)

    def on_epoch_end(self, epoch, logs=None):
        """Save the model at the end of each epoch if configured to."""
        if self.event == "epoch_end":
            if "%" in self.filepath:
                self.model.save(self.filepath % epoch)
            else:
                self.model.save(self.filepath)

    def on_train_end(self, logs=None):
        """Save the model at the end of training if configured to."""
        if self.event == "train_end":
            self.model.save(self.filepath)
class TensorBoardMatplotlibWriter(tf.keras.callbacks.Callback):
    """Callback for writing image summaries with visualizations during training.

    At the end of every epoch, `plot_fn` is called. The returned figure is rendered
    to PNG and logged as a TensorBoard image summary at step `epoch`.

    Attributes:
        log_dir: Path to log directory.
        plot_fn: Function with no arguments that returns a matplotlib figure handle.
        tag: Text to append to the summary label in TensorBoard.
        file_writer: Summary writer created for `log_dir`.
    """

    def __init__(
        self,
        log_dir: Text,
        plot_fn: Callable[[], matplotlib.figure.Figure],
        tag: Text = "viz",
    ):
        super().__init__()
        self.log_dir = log_dir
        self.plot_fn = plot_fn
        self.tag = tag
        self.file_writer = tf.summary.create_file_writer(self.log_dir)

    def on_epoch_end(self, epoch, logs=None):
        """Render `plot_fn()` and log it as an image summary for this epoch."""
        figure = self.plot_fn()
        try:
            # Render to in-memory PNG bytes.
            with io.BytesIO() as image_buffer:
                figure.savefig(image_buffer, format="png", pad_inches=0)
                png_bytes = image_buffer.getvalue()
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(figure)

        # Decode PNG to an RGBA image and add a leading batch dimension.
        image_tensor = tf.expand_dims(tf.image.decode_png(png_bytes, channels=4), axis=0)

        # Log to TensorBoard; flush so the image shows up without waiting for
        # the writer's periodic flush.
        with self.file_writer.as_default():
            tf.summary.image(name=self.tag, data=image_tensor, step=epoch)
        self.file_writer.flush()
class MatplotlibSaver(tf.keras.callbacks.Callback):
    """Callback for saving images rendered with matplotlib during training.

    This is useful for saving visualizations of the training to disk. It will be
    called at the end of each epoch.

    Attributes:
        plot_fn: Function with no arguments that returns a matplotlib figure handle.
            See `sleap.nn.training.Trainer.visualize_predictions` for example.
        save_folder: Path to a directory to save images to. This folder will be
            created if it does not exist.
        prefix: String that will be prepended to the filenames. This is useful for
            indicating which dataset the visualization was sampled from, for example.

    Notes:
        This will save images with the naming pattern:
            "{save_folder}/{prefix}.{epoch:04d}.png"
        or:
            "{save_folder}/{epoch:04d}.png"
        if a prefix is not specified. The epoch index is zero-padded to four
        digits, e.g. "val.0007.png".
    """

    def __init__(
        self,
        save_folder: Text,
        plot_fn: Callable[[], matplotlib.figure.Figure],
        prefix: Optional[Text] = None,
    ):
        """Initialize callback."""
        super().__init__()
        self.save_folder = save_folder
        self.plot_fn = plot_fn
        self.prefix = prefix

    def on_epoch_end(self, epoch, logs=None):
        """Render `plot_fn()` and save it as a PNG for this epoch."""
        figure = self.plot_fn()
        try:
            # exist_ok avoids a race between checking for and creating the folder.
            os.makedirs(self.save_folder, exist_ok=True)

            prefix = "" if self.prefix is None else self.prefix + "."
            figure_path = os.path.join(self.save_folder, f"{prefix}{epoch:04d}.png")
            figure.savefig(figure_path, format="png", pad_inches=0)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(figure)