"""GUI for monitoring training progress interactively."""
from collections import deque
import numpy as np
from time import time, sleep
import zmq
import jsonpickle
import logging
from typing import Optional
from PySide2 import QtCore, QtWidgets, QtGui, QtCharts
logger = logging.getLogger(__name__)
[docs]class LossViewer(QtWidgets.QMainWindow):
"""Qt window for showing in-progress training metrics sent over ZMQ."""
on_epoch = QtCore.Signal()
def __init__(
self,
zmq_context: Optional[zmq.Context] = None,
show_controller=True,
parent=None,
):
super(LossViewer, self).__init__(parent)
self.show_controller = show_controller
self.stop_button = None
self.redraw_batch_interval = 40
self.batches_to_show = -1 # -1 to show all
self.ignore_outliers = False
self.log_scale = True
self.reset()
self.setup_zmq(zmq_context)
def __del__(self):
self.unbind()
[docs] def close(self):
self.unbind()
super(LossViewer, self).close()
def unbind(self):
# close the zmq socket
if self.sub is not None:
self.sub.unbind(self.sub.LAST_ENDPOINT)
self.sub.close()
self.sub = None
if self.zmq_ctrl is not None:
url = self.zmq_ctrl.LAST_ENDPOINT
self.zmq_ctrl.unbind(url)
self.zmq_ctrl.close()
self.zmq_ctrl = None
# if we started out own zmq context, terminate it
if not self.ctx_given and self.ctx is not None:
self.ctx.term()
self.ctx = None
def reset(self, what=""):
self.chart = QtCharts.QtCharts.QChart()
self.series = dict()
self.color = dict()
self.series["batch"] = QtCharts.QtCharts.QScatterSeries()
self.series["epoch_loss"] = QtCharts.QtCharts.QLineSeries()
self.series["val_loss"] = QtCharts.QtCharts.QLineSeries()
self.series["batch"].setName("Batch Training Loss")
self.series["epoch_loss"].setName("Epoch Training Loss")
self.series["val_loss"].setName("Epoch Validation Loss")
self.color["batch"] = QtGui.QColor("blue")
self.color["epoch_loss"] = QtGui.QColor("green")
self.color["val_loss"] = QtGui.QColor("red")
for s in self.series:
self.series[s].pen().setColor(self.color[s])
self.series["batch"].setMarkerSize(8.0)
self.chart.addSeries(self.series["batch"])
self.chart.addSeries(self.series["epoch_loss"])
self.chart.addSeries(self.series["val_loss"])
axisX = QtCharts.QtCharts.QValueAxis()
axisX.setLabelFormat("%d")
axisX.setTitleText("Batches")
self.chart.addAxis(axisX, QtCore.Qt.AlignBottom)
# create the different Y axes that can be used
self.axisY = dict()
self.axisY["log"] = QtCharts.QtCharts.QLogValueAxis()
self.axisY["log"].setBase(10)
self.axisY["linear"] = QtCharts.QtCharts.QValueAxis()
# settings that apply to all Y axes
for axisY in self.axisY.values():
axisY.setLabelFormat("%f")
axisY.setLabelsVisible(True)
axisY.setMinorTickCount(1)
axisY.setTitleText("Loss")
# use the default Y axis
axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"]
self.chart.addAxis(axisY, QtCore.Qt.AlignLeft)
for series in self.chart.series():
series.attachAxis(axisX)
series.attachAxis(axisY)
# self.chart.legend().hide()
self.chart.legend().setVisible(True)
self.chart.legend().setAlignment(QtCore.Qt.AlignTop)
self.chartView = QtCharts.QtCharts.QChartView(self.chart)
self.chartView.setRenderHint(QtGui.QPainter.Antialiasing)
layout = QtWidgets.QVBoxLayout()
layout.addWidget(self.chartView)
if self.show_controller:
control_layout = QtWidgets.QHBoxLayout()
field = QtWidgets.QCheckBox("Log Scale")
field.setChecked(self.log_scale)
field.stateChanged.connect(lambda x: self.toggle("log_scale"))
control_layout.addWidget(field)
field = QtWidgets.QCheckBox("Ignore Outliers")
field.setChecked(self.ignore_outliers)
field.stateChanged.connect(lambda x: self.toggle("ignore_outliers"))
control_layout.addWidget(field)
control_layout.addWidget(QtWidgets.QLabel("Batches to Show:"))
# add field for how many batches to show in chart
field = QtWidgets.QComboBox()
# add options
self.batch_options = "200,1000,5000,All".split(",")
for opt in self.batch_options:
field.addItem(opt)
# set field to currently set value
cur_opt_str = (
"All" if self.batches_to_show < 0 else str(self.batches_to_show)
)
if cur_opt_str in self.batch_options:
field.setCurrentText(cur_opt_str)
# connection action for when user selects another option
field.currentIndexChanged.connect(
lambda x: self.set_batches_to_show(self.batch_options[x])
)
# store field as property and add to layout
self.batches_to_show_field = field
control_layout.addWidget(self.batches_to_show_field)
control_layout.addStretch(1)
self.stop_button = QtWidgets.QPushButton("Stop Training")
self.stop_button.clicked.connect(self.stop)
control_layout.addWidget(self.stop_button)
widget = QtWidgets.QWidget()
widget.setLayout(control_layout)
layout.addWidget(widget)
wid = QtWidgets.QWidget()
wid.setLayout(layout)
self.setCentralWidget(wid)
self.X = []
self.Y = []
self.t0 = None
self.current_job_output_type = what
self.epoch = 0
self.epoch_size = 1
self.last_epoch_val_loss = None
self.last_batch_number = 0
self.is_running = False
def toggle(self, what):
if what == "log_scale":
self.log_scale = not self.log_scale
self.update_y_axis()
elif what == "ignore_outliers":
self.ignore_outliers = not self.ignore_outliers
elif what == "entire_history":
if self.batches_to_show > 0:
self.batches_to_show = -1
else:
self.batches_to_show = 200
def set_batches_to_show(self, val):
if val.isdigit():
self.batches_to_show = int(val)
else:
self.batches_to_show = -1
def update_y_axis(self):
to = "log" if self.log_scale else "linear"
# remove other axes
for name, axisY in self.axisY.items():
if name != to:
if axisY in self.chart.axes():
self.chart.removeAxis(axisY)
for series in self.chart.series():
if axisY in series.attachedAxes():
series.detachAxis(axisY)
# add axis
axisY = self.axisY[to]
self.chart.addAxis(axisY, QtCore.Qt.AlignLeft)
for series in self.chart.series():
series.attachAxis(axisY)
def setup_zmq(self, zmq_context: Optional[zmq.Context]):
# Keep track of whether we're using an existing context (which we won't
# close when done) or are creating our own (which we should close).
self.ctx_given = zmq_context is not None
self.ctx = zmq.Context() if zmq_context is None else zmq_context
# Progress monitoring, SUBSCRIBER
self.sub = self.ctx.socket(zmq.SUB)
self.sub.subscribe("")
self.sub.bind("tcp://127.0.0.1:9001")
# Controller, PUBLISHER
self.zmq_ctrl = None
if self.show_controller:
self.zmq_ctrl = self.ctx.socket(zmq.PUB)
self.zmq_ctrl.bind("tcp://127.0.0.1:9000")
# Set timer to poll for messages every 20 milliseconds
self.timer = QtCore.QTimer()
self.timer.timeout.connect(self.check_messages)
self.timer.start(20)
[docs] def stop(self):
"""Action to stop training."""
if self.zmq_ctrl is not None:
# send command to stop training
logger.info("Sending command to stop training")
self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop")))
# Disable the button
if self.stop_button is not None:
self.stop_button.setText("Stopping...")
self.stop_button.setEnabled(False)
[docs] def add_datapoint(self, x, y, which="batch"):
"""
Adds data point to graph.
Args:
x: typically the batch number (out of all epochs, not just current)
y: typically the loss value
which: type of data point we're adding, possible values are
* batch (loss for batch)
* epoch_loss (loss for entire epoch)
* val_loss (validation loss for for epoch)
"""
# Keep track of all batch points
if which == "batch":
self.X.append(x)
self.Y.append(y)
# Redraw batch at intervals (faster than plotting each)
if x % self.redraw_batch_interval == 0:
if self.batches_to_show < 0 or len(self.X) < self.batches_to_show:
xs, ys = self.X, self.Y
else:
xs, ys = (
self.X[-self.batches_to_show :],
self.Y[-self.batches_to_show :],
)
points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0]
self.series["batch"].replace(points)
# Set X scale to show all points
dx = 0.5
self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx)
if self.ignore_outliers:
dy = np.ptp(ys) * 0.02
# Set Y scale to exclude outliers
q1, q3 = np.quantile(ys, (0.25, 0.75))
iqr = q3 - q1 # interquartile range
low = q1 - iqr * 1.5
high = q3 + iqr * 1.5
low = max(low, min(ys) - dy) # keep within range of data
high = min(high, max(ys) + dy)
else:
# Set Y scale to show all points
dy = np.ptp(ys) * 0.02
low = min(ys) - dy
high = max(ys) + dy
if self.log_scale:
low = max(low, 1e-5) # for log scale, low cannot be 0
self.chart.axisY().setRange(low, high)
else:
self.series[which].append(x, y)
def set_start_time(self, t0):
self.t0 = t0
self.is_running = True
def set_end(self):
self.is_running = False
def update_runtime(self):
if self.is_timer_running():
dt = time() - self.t0
dt_min, dt_sec = divmod(dt, 60)
title = f"Training Epoch <b>{self.epoch+1}</b> / "
title += f"Runtime: <b>{int(dt_min):02}:{int(dt_sec):02}</b>"
if self.last_epoch_val_loss is not None:
title += f"<br />Last Epoch Validation Loss: <b>{self.last_epoch_val_loss:.3e}</b>"
self.set_message(title)
def is_timer_running(self):
return self.t0 is not None and self.is_running
def set_message(self, text):
self.chart.setTitle(text)
[docs] def check_messages(
self, timeout=10, times_to_check: int = 10, do_update: bool = True
):
"""
Polls for ZMQ messages and adds any received data to graph.
The message is a dictionary encoded as JSON:
* event - options include
* train_begin
* train_end
* epoch_begin
* epoch_end
* batch_end
* what - this should match the type of model we're training and
ensures that we ignore old messages when we start monitoring
a new training session (when we're training multiple types
of models in a sequence, as for the top-down pipeline).
* logs - dictionary with data relevant for plotting, can include
* loss
* val_loss
"""
if self.sub and self.sub.poll(timeout, zmq.POLLIN):
msg = jsonpickle.decode(self.sub.recv_string())
# logger.info(msg)
if msg["event"] == "train_begin":
self.set_start_time(time())
self.current_job_output_type = msg["what"]
# make sure message matches current training job
if msg.get("what", "") == self.current_job_output_type:
if not self.is_timer_running():
# We must have missed the train_begin message, so start timer now
self.set_start_time(time())
if msg["event"] == "train_end":
self.set_end()
elif msg["event"] == "epoch_begin":
self.epoch = msg["epoch"]
elif msg["event"] == "epoch_end":
self.epoch_size = max(self.epoch_size, self.last_batch_number + 1)
self.add_datapoint(
(self.epoch + 1) * self.epoch_size,
msg["logs"]["loss"],
"epoch_loss",
)
if "val_loss" in msg["logs"].keys():
self.last_epoch_val_loss = msg["logs"]["val_loss"]
self.add_datapoint(
(self.epoch + 1) * self.epoch_size,
msg["logs"]["val_loss"],
"val_loss",
)
self.on_epoch.emit()
elif msg["event"] == "batch_end":
self.last_batch_number = msg["logs"]["batch"]
self.add_datapoint(
(self.epoch * self.epoch_size) + msg["logs"]["batch"],
msg["logs"]["loss"],
)
# Check for messages again (up to times_to_check times)
if times_to_check:
self.check_messages(
timeout=timeout, times_to_check=times_to_check - 1, do_update=False
)
if do_update:
self.update_runtime()
if __name__ == "__main__":
app = QtWidgets.QApplication([])
win = LossViewer()
win.show()
def test_point(x=[0]):
x[0] += 1
i = x[0] + 1
win.add_datapoint(i, i % 30 + 1)
t = QtCore.QTimer()
t.timeout.connect(test_point)
t.start(20)
win.set_message("Waiting for 3 seconds...")
t2 = QtCore.QTimer()
t2.timeout.connect(lambda: win.set_message("Running demo..."))
t2.start(3000)
app.exec_()