diff --git a/CHANGELOG.md b/CHANGELOG.md index f0a6735..0f6cafb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.6.1] - 2026-05-05 + +Internal cleanup pass and a few small breaking removals. For older Python or +older import paths, pin a previous `livelossplot` version. + +### Removed + +- The `livelossplot.keras` / `livelossplot.tf_keras` / `livelossplot.poutyne` / + `livelossplot.pytorch_ignite` import shortcuts (deprecated since 0.5.0). + Use `livelossplot.inputs.*` or `from livelossplot import PlotLossesKeras` etc. +- `PlotLosses.draw()` (deprecated since 0.5.0). Use `.send()`. +- `Plot2d` parameter `device` and helper method `_predict_pytorch` (pass a custom `predict` for PyTorch models). +- `LossSubplot` (unfinished; its constructor always raised `NotImplementedError`). + +### Changed + +- `ExtremaPrinter` parameter `massage_template` renamed to `message_template`. +- `Plot2d` parameter `valiation_data` renamed to `validation_data`. + +### Fixed + +- Mutable default arguments in `PlotLosses(outputs=...)` and `MatplotlibPlot(extra_plots=...)`. +- The `MainLogger.groups` property setter (which left `groups` as `None` when assigned `None`) was removed; `MainLogger(groups=None)` still defaults to `{}`. +- `ipython` declared as a runtime dependency. 
+ ## [0.6.0] - 2026-05-04 ### Changed diff --git a/examples/2d_prediction_maps.ipynb b/examples/2d_prediction_maps.ipynb index 5f4eaff..5dc290f 100644 --- a/examples/2d_prediction_maps.ipynb +++ b/examples/2d_prediction_maps.ipynb @@ -233,7 +233,7 @@ "model = MLP(6)\n", "\n", "plot2d = matplotlib_subplots.Plot2d(model, X_train, y_train,\n", - " valiation_data=(X_test, y_test),\n", + " validation_data=(X_test, y_test),\n", " margin=0.2, h=0.02, device=device)\n", "plot2d.predict = plot2d._predict_pytorch\n", "liveloss = PlotLosses(outputs=[plot2d])\n", @@ -272,7 +272,7 @@ "model = MLP(3, activation=nn.Sigmoid())\n", "\n", "plot2d = matplotlib_subplots.Plot2d(model, X_train, y_train,\n", - " valiation_data=(X_test, y_test),\n", + " validation_data=(X_test, y_test),\n", " margin=0.2, h=0.02, device=device)\n", "plot2d.predict = plot2d._predict_pytorch\n", "liveloss = PlotLosses(outputs=[plot2d])\n", diff --git a/livelossplot/__init__.py b/livelossplot/__init__.py index 4d23204..37c5578 100644 --- a/livelossplot/__init__.py +++ b/livelossplot/__init__.py @@ -2,50 +2,20 @@ .. include:: ../README.md """ -import sys -import warnings -from importlib.util import find_spec - +from . import inputs, outputs +from .inputs import * from .main_logger import MainLogger from .plot_losses import PlotLosses -from . import inputs -from .inputs import * -from . import outputs from .version import __version__ -_input_plugin_dict = { - 'keras': 'Keras', - 'tf_keras': 'KerasTF', - 'pytorch_ignite': 'Ignite', - 'poutyne': 'Poutyne', -} - - -class OldDependenciesFinder: - """ - Data package module loader finder. This class sits on `sys.meta_path` and returns the - loader it knows for a given path, if it knows a compatible loader. - """ - @classmethod - def find_spec(self, fullname: str, *_, **__): - """This functions is what gets executed by the loader. 
- Args: - fullname: name of the called module - """ - parts = fullname.split('.') - if len(parts) == 2 and parts[0] == 'livelossplot' and parts[1] in _input_plugin_dict: - name = parts[1] - msg = 'livelossplot.{name} will be deprecated, please use livelossplot.inputs.{name}\n' - msg += 'or use callback directly: from livelossplot import PlotLosses{new_name}' - warnings.warn(msg.format(name=name, new_name=_input_plugin_dict[name]), DeprecationWarning) - fullname = 'livelossplot.inputs.{name}'.format(name=name) - return find_spec(fullname) - return None - - -sys.meta_path.append(OldDependenciesFinder()) - __all__ = [ - 'MainLogger', 'inputs', 'outputs', 'PlotLosses', 'PlotLossesKeras', 'PlotLossesKerasTF', 'PlotLossesIgnite', - 'PlotLossesPoutyne' + 'MainLogger', + 'PlotLosses', + 'PlotLossesIgnite', + 'PlotLossesKeras', + 'PlotLossesKerasTF', + 'PlotLossesPoutyne', + '__version__', + 'inputs', + 'outputs', ] diff --git a/livelossplot/inputs/generic_keras.py b/livelossplot/inputs/generic_keras.py index c5dd7e4..5882a89 100644 --- a/livelossplot/inputs/generic_keras.py +++ b/livelossplot/inputs/generic_keras.py @@ -1,4 +1,3 @@ -from typing import Dict from livelossplot.plot_losses import PlotLosses @@ -12,7 +11,7 @@ def __init__(self, **kwargs): """ self.liveplot = PlotLosses(**kwargs) - def on_epoch_end(self, epoch: int, logs: Dict[str, float]): + def on_epoch_end(self, epoch: int, logs: dict[str, float]): """Send metrics to livelossplot Args: epoch: epoch number diff --git a/livelossplot/inputs/poutyne.py b/livelossplot/inputs/poutyne.py index 81fa5b6..382b70a 100644 --- a/livelossplot/inputs/poutyne.py +++ b/livelossplot/inputs/poutyne.py @@ -1,6 +1,6 @@ -from typing import Dict from poutyne.framework import Callback + from ..plot_losses import PlotLosses @@ -11,7 +11,7 @@ def __init__(self, **kwargs): Args: **kwargs: keyword arguments that will be passed to PlotLosses constructor """ - super(PlotLossesCallback, self).__init__() + super().__init__() self.liveplot = 
PlotLosses(**kwargs) self.metrics = None @@ -21,7 +21,7 @@ def on_train_begin(self, logs): self.metrics = list(metrics) self.metrics += ['val_' + metric for metric in metrics] - def on_epoch_end(self, epoch: int, logs: Dict[str, float]): + def on_epoch_end(self, epoch: int, logs: dict[str, float]): """Send metrics to livelossplot Args: epoch: epoch number diff --git a/livelossplot/inputs/pytorch_ignite.py b/livelossplot/inputs/pytorch_ignite.py index ec02e0c..6ffe90e 100644 --- a/livelossplot/inputs/pytorch_ignite.py +++ b/livelossplot/inputs/pytorch_ignite.py @@ -1,12 +1,12 @@ -from typing import Optional import ignite.engine from ignite.handlers import global_step_from_engine + from livelossplot.plot_losses import PlotLosses class PlotLossesCallback: - def __init__(self, train_engine: Optional[ignite.engine.Engine] = None, **kwargs): + def __init__(self, train_engine: ignite.engine.Engine | None = None, **kwargs): """ Args: train_engine: engine with global step information, send metohod callback will be attached to it @@ -51,5 +51,5 @@ def store(self, engine: ignite.engine.Engine): if not self.train_engine: self.send() - def send(self, _: Optional[ignite.engine.Engine] = None): + def send(self, _: ignite.engine.Engine | None = None): self.liveplot.send() diff --git a/livelossplot/inputs/tf_keras.py b/livelossplot/inputs/tf_keras.py index 8e2d074..938e499 100644 --- a/livelossplot/inputs/tf_keras.py +++ b/livelossplot/inputs/tf_keras.py @@ -1,4 +1,5 @@ from tensorflow import keras + from .generic_keras import _PlotLossesCallback diff --git a/livelossplot/main_logger.py b/livelossplot/main_logger.py index f4d9aed..cd1c887 100644 --- a/livelossplot/main_logger.py +++ b/livelossplot/main_logger.py @@ -1,74 +1,66 @@ import re -from collections import OrderedDict, defaultdict -from typing import NamedTuple, Dict, Iterable, List, Pattern, Tuple, Optional, Union +from collections import defaultdict +from collections.abc import Iterable +from re import Pattern +from 
typing import NamedTuple + + +class LogItem(NamedTuple): + step: int + value: float + -# Value of metrics - for value later, we want to support numpy arrays etc -LogItem = NamedTuple('LogItem', [('step', int), ('value', float)]) COMMON_NAME_SHORTCUTS = { 'acc': 'Accuracy', 'nll': 'Log Loss (cost function)', 'mse': 'Mean Squared Error', - 'loss': 'Loss' + 'loss': 'Loss', } class MainLogger: - """ - Main logger - the aim of this class is to store every log from training - Log is a float value with corresponding training engine step - """ + """Stores per-step training metrics and groups them for plotting.""" + def __init__( self, - groups: Optional[Dict[str, List[str]]] = None, - metric_to_name: Optional[Dict[str, str]] = None, + groups: dict[str, list[str]] | None = None, + metric_to_name: dict[str, str] | None = None, from_step: int = 0, current_step: int = -1, auto_generate_groups_if_not_available: bool = True, auto_generate_metric_to_name: bool = True, - group_patterns: Iterable[Tuple[Pattern, str]] = ( + group_patterns: Iterable[tuple[str | Pattern, str]] = ( (r'^(?!val(_|-))(.*)', 'training'), (r'^(val(_|-))(.*)', 'validation'), ), - step_names: Union[str, Dict[str, str]] = 'epoch' + step_names: str | dict[str, str] = 'epoch', ): """ Args: - groups: dictionary with grouped metrics for example the group 'accuracy' can contains different stages - for example 'validation_accuracy', 'training_accuracy' etc. 
- metric_to_name: transformation of metric name which can be used to display name - we can have short name in the code (acc), but full name on charts (Accuracy) - from_step: step to show in plots (positive: show steps from this one, negative: show only this many last steps) - current_step: current step of the train loop - auto_generate_groups_if_not_available: flag, that enable auto-creation of metric groups - base on group patterns - auto_generate_metric_to_name: flag, that enable auto-creation of metric long names - base on common shortcuts - group_patterns: you can put there regular expressions to match a few metric names with group - and replace its name using second value - step_names: dictionary with a name of x axis for each metrics group or one name for all metrics + groups: pre-defined metric groups, e.g. {'accuracy': ['acc', 'val_acc']}. + metric_to_name: short-name to display-name overrides, e.g. {'acc': 'Accuracy'}. + from_step: positive = show steps from this one; negative = show only this many last steps. + current_step: current step of the train loop. + auto_generate_groups_if_not_available: derive groups from `group_patterns` if `groups` is empty. + auto_generate_metric_to_name: derive display names from `group_patterns` and `COMMON_NAME_SHORTCUTS`. + group_patterns: regex patterns mapping metric names to group labels (training/validation). + step_names: x-axis label, either a single string or per-group mapping. 
""" - self.log_history = {} - self.groups = groups if groups is not None else {} - self.metric_to_name = metric_to_name if metric_to_name else {} + self.log_history: dict[str, list[LogItem]] = {} + self.groups: dict[str, list[str]] = groups if groups is not None else {} + self.metric_to_name: dict[str, str] = metric_to_name or {} self.from_step = from_step self.current_step = current_step - self.auto_generate_groups = all((not groups, auto_generate_groups_if_not_available)) + self.auto_generate_groups = not groups and auto_generate_groups_if_not_available self.auto_generate_metric_to_name = auto_generate_metric_to_name self.group_patterns = tuple((re.compile(pattern), replace_with) for pattern, replace_with in group_patterns) if isinstance(step_names, str): - self.step_names = defaultdict(lambda: step_names) + self.step_names: dict[str, str] = defaultdict(lambda: step_names) else: self.step_names = defaultdict(lambda: 'epoch', step_names) - def update(self, logs: Dict[str, float], current_step: Optional[int] = None) -> None: - """ - Args: - logs: dictionary with metric names and values - current_step: current step of the training loop - - Notes: - Loop step can be controlled outside or inside main logger with autoincrement of self.current_step - """ + def update(self, logs: dict[str, float], current_step: int | None = None) -> None: + """Append a new step of metric values.""" if current_step is None: self.current_step += 1 current_step = self.current_step @@ -79,98 +71,69 @@ def update(self, logs: Dict[str, float], current_step: Optional[int] = None) -> self._add_new_metric(k) self.log_history[k].append(LogItem(step=current_step, value=v)) - def _add_new_metric(self, metric_name: str): - """Add empty list for a new metric and extend metric name transformations - Args: - metric_name: name of metric that will be added to log_history as empty list - """ - self.log_history[metric_name] = [] - if not self.metric_to_name.get(metric_name): - 
self._auto_generate_metrics_to_name(metric_name) - - def _auto_generate_metrics_to_name(self, metric_name: str): - """The function generate transforms for metric names base on patterns - Args: - metric_name: name of new appended metric + def reset(self) -> None: + """Clear all logs, groups, and reset the step counter.""" + self.log_history = {} + self.groups = {} + self.current_step = -1 - Example: - It can create transformation from val_acc to Validation Accuracy - """ - suffix = self._find_suffix_with_group_patterns(metric_name) - if suffix is None and suffix != metric_name: - return - similar_metric_names = [m for m in self.log_history.keys() if m.endswith(suffix)] - if len(similar_metric_names) == 1: - return - for name in similar_metric_names: - new_name = name - for pattern_to_replace, replace_with in self.group_patterns: - new_name = re.sub(pattern_to_replace, replace_with, new_name) - if suffix in COMMON_NAME_SHORTCUTS.keys(): - new_name = new_name.replace(suffix, COMMON_NAME_SHORTCUTS[suffix]) - self.metric_to_name[name] = new_name + def grouped_log_history( + self, + raw_names: bool = False, + raw_group_names: bool = False, + ) -> dict[str, dict[str, list[LogItem]]]: + """Return logs grouped by metric group, sorted alphabetically. - def grouped_log_history(self, - raw_names: bool = False, - raw_group_names: bool = False) -> Dict[str, Dict[str, List[LogItem]]]: - """ Args: - raw_names: flag, return raw names instead of transformed by metric to name (as in update() input dictionary) - raw_group_names: flag, return group names without transforming them with COMMON_NAME_SHORTCUTS - - Returns: - logs grouped by metric groups - groups are passed in the class constructor - - Notes: - method use group patterns instead of groups if they are available + raw_names: keep original metric names instead of `metric_to_name` mapping. + raw_group_names: keep original group names instead of `COMMON_NAME_SHORTCUTS` mapping. 
""" if self.auto_generate_groups: self.groups = self._auto_generate_groups() - ret = {} - sorted_groups = OrderedDict(sorted(self.groups.items(), key=lambda t: t[0])) - for group_name, names in sorted_groups.items(): - group_name = group_name if raw_group_names else COMMON_NAME_SHORTCUTS.get(group_name, group_name) - ret[group_name] = { + result: dict[str, dict[str, list[LogItem]]] = {} + for group_name, names in sorted(self.groups.items()): + display_group = group_name if raw_group_names else COMMON_NAME_SHORTCUTS.get(group_name, group_name) + result[display_group] = { name if raw_names else self.metric_to_name.get(name, name): self.history_shorter(name) for name in names } - return ret - - def history_shorter(self, name: str, full: bool = False) -> List[LogItem]: - """ - Args: - name: metrics name, e.g. 'val_acc' or 'loss' - full: flag, if True return all, otherwise as specified by the from_step parameter + return result - Returns: - a list of log items - """ + def history_shorter(self, name: str, full: bool = False) -> list[LogItem]: + """Return the log history for one metric, trimmed by `from_step` unless `full=True`.""" if name not in self.log_history: return [] log_metrics = self.log_history[name] if full or self.from_step == 0: return log_metrics - elif self.from_step > 0: + if self.from_step > 0: return [x for x in log_metrics if x.step >= self.from_step] - else: - current_from_step = self.current_step + self.from_step - return [x for x in log_metrics if x.step >= current_from_step] + threshold = self.current_step + self.from_step + return [x for x in log_metrics if x.step >= threshold] - def _auto_generate_groups(self) -> Dict[str, List[str]]: - """ - Returns: - groups generated with group patterns + def _add_new_metric(self, metric_name: str) -> None: + self.log_history[metric_name] = [] + if not self.metric_to_name.get(metric_name): + self._auto_generate_metrics_to_name(metric_name) - Notes: - Auto create groups base on val_ prefix - this step is skipped 
if groups are set - or if group patterns are available - """ - groups = {} - for key in self.log_history.keys(): + def _auto_generate_metrics_to_name(self, metric_name: str) -> None: + suffix = self._find_suffix_with_group_patterns(metric_name) + similar_metric_names = [m for m in self.log_history if m.endswith(suffix)] + if len(similar_metric_names) == 1: + return + for name in similar_metric_names: + new_name = name + for pattern, replace_with in self.group_patterns: + new_name = re.sub(pattern, replace_with, new_name) + if suffix in COMMON_NAME_SHORTCUTS: + new_name = new_name.replace(suffix, COMMON_NAME_SHORTCUTS[suffix]) + self.metric_to_name[name] = new_name + + def _auto_generate_groups(self) -> dict[str, list[str]]: + groups: dict[str, list[str]] = {} + for key in self.log_history: abs_key = self._find_suffix_with_group_patterns(key) - if not groups.get(abs_key): - groups[abs_key] = [] - groups[abs_key].append(key) + groups.setdefault(abs_key, []).append(key) return groups def _find_suffix_with_group_patterns(self, metric_name: str) -> str: @@ -180,33 +143,3 @@ def _find_suffix_with_group_patterns(self, metric_name: str) -> str: if match: suffix = match.groups()[-1] return suffix - - def reset(self) -> None: - """Method clears logs, groups and reset step counter""" - self.log_history = {} - self.groups = {} - self.current_step = -1 - - @property - def groups(self) -> Dict[str, List[str]]: - """groups getter""" - return self._groups - - @groups.setter - def groups(self, value: Dict[str, List[str]]) -> None: - """groups setter - groups should be dictionary""" - if value is None: - self._groups = {} - self._groups = value - - @property - def log_history(self) -> Dict[str, List[LogItem]]: - """logs getter""" - return self._log_history - - @log_history.setter - def log_history(self, value: Dict[str, List[LogItem]]) -> None: - """logs setter - logs can not be overwritten - you can only reset it to empty state""" - if len(value) > 0: - raise RuntimeError('Cannot 
overwrite log history with non empty dictionary') - self._log_history = value diff --git a/livelossplot/outputs/__init__.py b/livelossplot/outputs/__init__.py index 458dce5..2eda8f7 100644 --- a/livelossplot/outputs/__init__.py +++ b/livelossplot/outputs/__init__.py @@ -1,20 +1,17 @@ -# technical - -from .base_output import BaseOutput - -# default - -from .matplotlib_plot import MatplotlibPlot -from .extrema_printer import ExtremaPrinter - -# with external dependencies -# import are respective __init__ methods -# hack-ish, but works (and I am not aware of a more proper way to do so) - -from .bokeh_plot import BokehPlot -from .tensorboard_logger import TensorboardLogger -from .tensorboard_tf_logger import TensorboardTFLogger - -# with external dependencies - -from . import matplotlib_subplots +from . import matplotlib_subplots +from .base_output import BaseOutput +from .bokeh_plot import BokehPlot +from .extrema_printer import ExtremaPrinter +from .matplotlib_plot import MatplotlibPlot +from .tensorboard_logger import TensorboardLogger +from .tensorboard_tf_logger import TensorboardTFLogger + +__all__ = [ + "BaseOutput", + "BokehPlot", + "ExtremaPrinter", + "MatplotlibPlot", + "TensorboardLogger", + "TensorboardTFLogger", + "matplotlib_subplots", +] diff --git a/livelossplot/outputs/base_output.py b/livelossplot/outputs/base_output.py index 1215409..b39c687 100644 --- a/livelossplot/outputs/base_output.py +++ b/livelossplot/outputs/base_output.py @@ -1,27 +1,18 @@ from abc import ABC, abstractmethod +from typing import Literal from livelossplot.main_logger import MainLogger +OutputMode = Literal['notebook', 'script'] + class BaseOutput(ABC): @abstractmethod - def send(self, logger: MainLogger): - """Abstract method - handle logs for a plugin""" - ... - - def close(self): - """Overwrite it with last steps""" - ... 
+ def send(self, logger: MainLogger) -> None: + """Handle logs for a plugin""" - def set_output_mode(self, mode: str): - """Some of output plugins needs to know target format""" - assert mode in ('notebook', 'script') - self._set_output_mode(mode) + def close(self) -> None: # noqa: B027 — optional override hook + """Run at the end of training, if needed""" - def _set_output_mode(self, mode: str): - """ - Args: - mode: mode for callbacks - some of outputs need to change some behaviors, - depending on the working environment (scripts and jupyter notebooks) - """ - ... + def set_output_mode(self, mode: OutputMode) -> None: # noqa: B027 — optional override hook + """Notify plugin whether we're in a notebook or script context""" diff --git a/livelossplot/outputs/bokeh_plot.py b/livelossplot/outputs/bokeh_plot.py index 105a540..f2c84a8 100644 --- a/livelossplot/outputs/bokeh_plot.py +++ b/livelossplot/outputs/bokeh_plot.py @@ -1,8 +1,7 @@ import sys -from typing import List, Dict, Tuple -from livelossplot.main_logger import MainLogger, LogItem -from livelossplot.outputs.base_output import BaseOutput +from livelossplot.main_logger import LogItem, MainLogger +from livelossplot.outputs.base_output import BaseOutput, OutputMode class BokehPlot(BaseOutput): @@ -11,7 +10,7 @@ def __init__( self, max_cols: int = 2, skip_first: int = 2, - cell_size: Tuple[int, int] = (400, 300), + cell_size: tuple[int, int] = (400, 300), output_file: str = './bokeh_output.html' ): """ @@ -21,7 +20,7 @@ def __init__( cell_size: size of one chart output_file: file to save the output """ - from bokeh import plotting, io, palettes + from bokeh import io, palettes, plotting self.plotting = plotting self.io = io self.plot_width, self.plot_height = cell_size @@ -42,7 +41,7 @@ def send(self, logger: MainLogger) -> None: log_groups = logger.grouped_log_history() new_grid_plot = False - for idx, (group_name, group_logs) in enumerate(log_groups.items(), start=1): + for _idx, (group_name, group_logs) in 
enumerate(log_groups.items(), start=1): fig = self.figures.get(group_name) if not fig: fig = self.plotting.figure(title=group_name) @@ -95,7 +94,7 @@ def _send_colab(self, logger: MainLogger) -> None: else: self._colab_handle.update(payload) - def _draw_metric_subplot(self, fig, group_logs: Dict[str, List[LogItem]]): + def _draw_metric_subplot(self, fig, group_logs: dict[str, list[LogItem]]): """ Args: fig: bokeh Figure @@ -144,12 +143,9 @@ def _create_grid_plot(self): ) self.target = self.plotting.show(self.grid, notebook_handle=self.is_notebook) - def _set_output_mode(self, mode: str): - """Set notebook or script mode""" + def set_output_mode(self, mode: OutputMode) -> None: self.is_notebook = mode == 'notebook' if self.is_notebook: - # Bokeh auto-detects Colab in recent versions; passing - # notebook_type='colab' explicitly raises in Bokeh 3.x. self.io.output_notebook() else: self.io.output_file(self.output_file) diff --git a/livelossplot/outputs/extrema_printer.py b/livelossplot/outputs/extrema_printer.py index a1a7004..8376840 100644 --- a/livelossplot/outputs/extrema_printer.py +++ b/livelossplot/outputs/extrema_printer.py @@ -1,44 +1,41 @@ -from typing import Dict, List +from livelossplot.main_logger import LogItem, MainLogger -from livelossplot.main_logger import LogItem -from livelossplot.main_logger import MainLogger from .base_output import BaseOutput class ExtremaPrinter(BaseOutput): def __init__( self, - massage_template: str = '\t{metric_name:16} \t (min: {min:8.3f},' + message_template: str = '\t{metric_name:16} \t (min: {min:8.3f},' ' max: {max:8.3f}, cur: {current:8.3f})' ): """ Args: - massage_template: you can specify massage which use all or a few values (min, max, current) + message_template: format string with min, max, current placeholders """ - self.massage_template = massage_template + self.message_template = message_template self.last_message = "" def send(self, logger: MainLogger): - """Create massages with log_history and massage 
template""" + """Print min/max/current per metric to stdout""" log_groups = logger.grouped_log_history() - self.last_message = '\n'.join(self._create_massages(log_groups)) + self.last_message = '\n'.join(self._create_messages(log_groups)) print(self.last_message) - def _create_massages(self, log_groups: Dict[str, Dict[str, List[LogItem]]]) -> List[str]: - """Create massages""" - massages = [] + def _create_messages(self, log_groups: dict[str, dict[str, list[LogItem]]]) -> list[str]: + messages = [] for group_name, group_logs in log_groups.items(): - massages.append(group_name) + messages.append(group_name) for metric_name, log_items in group_logs.items(): if len(log_items) == 0: - msg = '\t{metric_name:16} \t (no values!)'.format(metric_name=metric_name) + msg = f'\t{metric_name:16} \t (no values!)' else: values = [log_item.value for log_item in log_items] - min_val = min(values) - max_val = max(values) - current_val = values[-1] - msg = self.massage_template.format( - metric_name=metric_name, min=min_val, max=max_val, current=current_val + msg = self.message_template.format( + metric_name=metric_name, + min=min(values), + max=max(values), + current=values[-1], ) - massages.append(msg) - return massages + messages.append(msg) + return messages diff --git a/livelossplot/outputs/matplotlib_plot.py b/livelossplot/outputs/matplotlib_plot.py index a7025d6..c365f91 100644 --- a/livelossplot/outputs/matplotlib_plot.py +++ b/livelossplot/outputs/matplotlib_plot.py @@ -1,30 +1,27 @@ import math -from typing import Tuple, List, Dict, Optional, Callable, Literal +from collections.abc import Callable -import warnings - -import numpy as np -import matplotlib import matplotlib.pyplot as plt +import numpy as np from IPython.display import clear_output -from livelossplot.main_logger import MainLogger, LogItem -from livelossplot.outputs.base_output import BaseOutput + +from livelossplot.main_logger import LogItem, MainLogger +from livelossplot.outputs.base_output import BaseOutput, 
OutputMode class MatplotlibPlot(BaseOutput): - """NOTE: Removed figsize and dynamix_x_axis.""" def __init__( self, - cell_size: Tuple[int, int] = (6, 4), + cell_size: tuple[int, int] = (6, 4), max_cols: int = 2, - max_epoch: int = None, + max_epoch: int | None = None, skip_first: int = 2, - extra_plots: List[Callable[[MainLogger], None]] = [], - figpath: Optional[str] = None, - after_subplot: Optional[Callable[[plt.Axes, str, str], None]] = None, - before_plots: Optional[Callable[[plt.Figure, np.ndarray, int], None]] = None, - after_plots: Optional[Callable[[plt.Figure], None]] = None, - figsize: Optional[Tuple[int, int]] = None, + extra_plots: list[Callable[[plt.Axes, MainLogger], None]] | None = None, + figpath: str | None = None, + after_subplot: Callable[[plt.Axes, str, str], None] | None = None, + before_plots: Callable[[plt.Figure, np.ndarray, int], None] | None = None, + after_plots: Callable[[plt.Figure], None] | None = None, + figsize: tuple[int, int] | None = None, ): """ Args: @@ -41,16 +38,16 @@ def __init__( """ self.cell_size = cell_size self.max_cols = max_cols - self.skip_first = skip_first # think about it - self.extra_plots = extra_plots + self.skip_first = skip_first + self.extra_plots = extra_plots if extra_plots is not None else [] self.max_epoch = max_epoch self.figpath = figpath - self.file_idx = 0 # now only for saving files - self._after_subplot = after_subplot if after_subplot else self._default_after_subplot - self._before_plots = before_plots if before_plots else self._default_before_plots - self._after_plots = after_plots if after_plots else self._default_after_plots + self.file_idx = 0 + self._after_subplot = after_subplot or self._default_after_subplot + self._before_plots = before_plots or self._default_before_plots + self._after_plots = after_plots or self._default_after_plots self.figsize = figsize - self.output_mode: Literal['notebook', 'script'] = "notebook" + self.output_mode: OutputMode = "notebook" def send(self, logger: 
MainLogger): """Draw figures with metrics and show""" @@ -119,7 +116,7 @@ def _default_after_plots(self, fig: plt.Figure): else: plt.show() - def _draw_metric_subplot(self, ax: plt.Axes, group_logs: Dict[str, List[LogItem]], group_name: str, x_label: str): + def _draw_metric_subplot(self, ax: plt.Axes, group_logs: dict[str, list[LogItem]], group_name: str, x_label: str): """ Args: ax: matplotlib Axes @@ -139,13 +136,5 @@ def _draw_metric_subplot(self, ax: plt.Axes, group_logs: Dict[str, List[LogItem] self._after_subplot(ax, group_name, x_label) - def _not_inline_warning(self): - backend = matplotlib.get_backend() - if "backend_inline" not in backend: - warnings.warn( - "livelossplot requires inline plots.\nYour current backend is: {}" - "\nRun in a Jupyter environment and execute '%matplotlib inline'.".format(backend) - ) - - def _set_output_mode(self, mode: Literal['notebook', 'script']): + def set_output_mode(self, mode: OutputMode) -> None: self.output_mode = mode diff --git a/livelossplot/outputs/matplotlib_subplots.py b/livelossplot/outputs/matplotlib_subplots.py index b11adc8..f2f81ae 100644 --- a/livelossplot/outputs/matplotlib_subplots.py +++ b/livelossplot/outputs/matplotlib_subplots.py @@ -1,13 +1,14 @@ -from typing import Literal, Optional - import matplotlib.pyplot as plt import numpy as np +from IPython.display import clear_output from matplotlib.colors import ListedColormap +from livelossplot.outputs.base_output import OutputMode + class BaseSubplot: def __init__(self): - self.output_mode: Literal['notebook', 'script'] = 'notebook' + self.output_mode: OutputMode = 'notebook' def draw(self, *args, **kwargs): raise NotImplementedError @@ -15,70 +16,18 @@ def draw(self, *args, **kwargs): def __call__(self, *args, **kwargs): self.draw(*args, **kwargs) - def set_output_mode(self, mode: Literal['notebook', 'script']): + def set_output_mode(self, mode: OutputMode) -> None: self.output_mode = mode - def _present(self, fig: plt.Figure): - """Render fig 
appropriately for the current output mode.""" + def _present(self, fig: plt.Figure) -> None: if self.output_mode == 'notebook': - try: - from IPython.display import clear_output - clear_output(wait=True) - except ImportError: - pass + clear_output(wait=True) plt.show() else: plt.draw() plt.pause(0.05) -class LossSubplot(BaseSubplot): - """To rewrire, this one now won't work""" - def __init__( - self, metric, title="", series_fmt={ - 'training': '{}', - 'validation': 'val_{}' - }, skip_first=2, max_epoch=None - ): - super().__init__() - self.metric = metric - self.title = title - self.series_fmt = series_fmt - self.skip_first = skip_first - self.max_epoch = max_epoch - raise NotImplementedError() - - def _how_many_to_skip(self, log_length, skip_first): - if log_length < skip_first: - return 0 - elif log_length < 2 * skip_first: - return log_length - skip_first - else: - return skip_first - - def draw(self, logs): - skip = self._how_many_to_skip(len(logs), self.skip_first) - - if self.max_epoch is not None: - plt.xlim(1 + skip, self.max_epoch) - - for serie_label, serie_fmt in self.series_fmt.items(): - - serie_metric_name = serie_fmt.format(self.metric) - serie_metric_logs = [ - (log.get('_i', i + 1), log[serie_metric_name]) - for i, log in enumerate(logs[skip:]) if serie_metric_name in log - ] - - if len(serie_metric_logs) > 0: - xs, ys = zip(*serie_metric_logs) - plt.plot(xs, ys, label=serie_label) - - plt.title(self.title) - plt.xlabel('epoch') - plt.legend(loc='center right') - - class Plot1D(BaseSubplot): def __init__(self, model, X, Y): super().__init__() @@ -87,7 +36,6 @@ def __init__(self, model, X, Y): self.Y = Y def predict(self, model, X): - # e.g. 
model(torch.fromnumpy(X)).detach().numpy() return model.predict(X) def draw(self, *args, **kwargs): @@ -98,37 +46,27 @@ def draw(self, *args, **kwargs): class Plot2d(BaseSubplot): - def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25, device='cpu'): + def __init__(self, model, X, Y, validation_data=(None, None), h=0.02, margin=0.25): super().__init__() self.model = model self.X = X self.Y = Y - self.X_test, self.Y_test = valiation_data + self.X_test, self.Y_test = validation_data self.cm_bg = plt.cm.RdBu self.cm_points = ListedColormap(['#FF0000', '#0000FF']) x_min = X[:, 0].min() - margin x_max = X[:, 0].max() + margin - y_min = X[:, 1].min() - margin y_max = X[:, 1].max() + margin - self.xx, self.yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - self.torch_device = device - - self._fig: Optional[plt.Figure] = None - self._ax: Optional[plt.Axes] = None - - def _predict_pytorch(self, model, x_numpy): - import torch - x = torch.from_numpy(x_numpy).to(self.torch_device).float() - return model(x).softmax(dim=1).detach().cpu().numpy() + self._fig: plt.Figure | None = None + self._ax: plt.Axes | None = None def predict(self, model, X): - # e.g. 
model(torch.fromnumpy(X)).detach().numpy() return model.predict(X) def send(self, logger): @@ -139,7 +77,7 @@ def send(self, logger): Z = self.predict(self.model, np.c_[self.xx.ravel(), self.yy.ravel()])[:, 1] Z = Z.reshape(self.xx.shape) - self._ax.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8) + self._ax.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=0.8) self._ax.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points) if self.X_test is not None: self._ax.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3) diff --git a/livelossplot/outputs/tensorboard_logger.py b/livelossplot/outputs/tensorboard_logger.py index 173fec3..6c1d727 100644 --- a/livelossplot/outputs/tensorboard_logger.py +++ b/livelossplot/outputs/tensorboard_logger.py @@ -1,44 +1,35 @@ -from datetime import datetime -from os import path -from typing import Optional +from datetime import datetime, timezone +from pathlib import Path from livelossplot.main_logger import MainLogger from livelossplot.outputs.base_output import BaseOutput class TensorboardLogger(BaseOutput): - """ - Class write logs to TensorBoard (using pure TensorBoard, not one from TensorFlow). - """ - def __init__(self, logdir: str = "./tensorboard_logs/", run_id: Optional[str] = None): + """Write logs to TensorBoard via the standalone `tensorboard` package.""" + + def __init__(self, logdir: str | Path = "./tensorboard_logs/", run_id: str | None = None): """ Args: - logdir: dir where TensorBoard events will be written - run_id: name for log id, otherwise it usses datetime + logdir: dir where TensorBoard events will be written. + run_id: name for the run; defaults to the current timestamp. 
""" from tensorboard import summary self.summary = summary - run_id = datetime.now().isoformat()[:-7].replace("T", " ").replace(":", "_") if run_id is None else run_id - self._path = path.join(logdir, run_id) - self.writer = summary.create_file_writer(self._path) + if run_id is None: + run_id = datetime.now(tz=timezone.utc).strftime('%Y-%m-%d %H_%M_%S') + self._path = Path(logdir) / run_id + self.writer = summary.create_file_writer(str(self._path)) - def close(self): - """Close tensorboard writer""" + def close(self) -> None: self.writer.close() - def log_scalar(self, name: str, value: float, global_step: float): - """ - Args: - name: name of metric - value: float value of metric - global_step: current step of the training loop - """ + def log_scalar(self, name: str, value: float, global_step: float) -> None: with self.writer.as_default(): self.summary.scalar(name, value, step=global_step) self.writer.flush() - def send(self, logger: MainLogger): - """Take log history from logger and store it in tensorboard event""" + def send(self, logger: MainLogger) -> None: for name, log_items in logger.log_history.items(): - last_log_item = log_items[-1] - self.log_scalar(name, last_log_item.value, last_log_item.step) + last = log_items[-1] + self.log_scalar(name, last.value, last.step) diff --git a/livelossplot/outputs/tensorboard_tf_logger.py b/livelossplot/outputs/tensorboard_tf_logger.py index 1ecb282..1f62e34 100644 --- a/livelossplot/outputs/tensorboard_tf_logger.py +++ b/livelossplot/outputs/tensorboard_tf_logger.py @@ -1,43 +1,35 @@ -from datetime import datetime -from os import path +from datetime import datetime, timezone +from pathlib import Path from livelossplot.main_logger import MainLogger from livelossplot.outputs.base_output import BaseOutput class TensorboardTFLogger(BaseOutput): - """ - Class write logs to TensorBoard (from TensorFlow). 
- """ - def __init__(self, logdir="./tensorboard_logs/", run_id=None): + """Write logs to TensorBoard via TensorFlow's bundled `tf.summary`.""" + + def __init__(self, logdir: str | Path = "./tensorboard_logs/", run_id: str | None = None): """ Args: - logdir: dir where TensorBoard events will be written - run_id: name for log id, otherwise it usses datetime + logdir: dir where TensorBoard events will be written. + run_id: name for the run; defaults to the current timestamp. """ from tensorflow import summary self.summary = summary - run_id = datetime.now().isoformat()[:-7].replace("T", " ").replace(":", "_") if run_id is None else run_id - self._path = path.join(logdir, run_id) - self.writer = summary.create_file_writer(self._path) + if run_id is None: + run_id = datetime.now(tz=timezone.utc).strftime('%Y-%m-%d %H_%M_%S') + self._path = Path(logdir) / run_id + self.writer = summary.create_file_writer(str(self._path)) - def close(self): - """Close tensorboard writer""" + def close(self) -> None: self.writer.close() - def log_scalar(self, name: str, value: float, global_step: int): - """ - Args: - name: name of metric - value: float value of metric - global_step: current step of the training loop - """ + def log_scalar(self, name: str, value: float, global_step: int) -> None: with self.writer.as_default(): self.summary.scalar(name, value, step=global_step) self.writer.flush() - def send(self, logger: MainLogger): - """Take log history from logger and store it in tensorboard event""" + def send(self, logger: MainLogger) -> None: for name, log_items in logger.log_history.items(): - last_log_item = log_items[-1] - self.log_scalar(name, last_log_item.value, last_log_item.step) + last = log_items[-1] + self.log_scalar(name, last.value, last.step) diff --git a/livelossplot/plot_losses.py b/livelossplot/plot_losses.py index 06d5b5e..b8db2e8 100644 --- a/livelossplot/plot_losses.py +++ b/livelossplot/plot_losses.py @@ -1,52 +1,47 @@ -import warnings -from typing import Type, 
TypeVar, List, Union, Optional, Tuple, Literal +from typing import TypeVar import livelossplot -from livelossplot.main_logger import MainLogger from livelossplot import outputs +from livelossplot.main_logger import MainLogger +from livelossplot.outputs.base_output import OutputMode from livelossplot.outputs.matplotlib_plot import MatplotlibPlot BO = TypeVar('BO', bound=outputs.BaseOutput) -def get_mode() -> Literal['notebook', 'script']: +def get_mode() -> OutputMode: + """Detect whether we're running in a notebook (Jupyter, Colab) or a script.""" try: from IPython import get_ipython - ipython = get_ipython() - if ipython is None: - return 'script' - name = ipython.__class__.__name__ - if name == "ZMQInteractiveShell" or name == "Shell": - # Shell is in Colab - return "notebook" - elif name == "TerminalInteractiveShell": - return "script" - print(f"Unknown IPython mode: {name}. Assuming notebook mode.") - return "notebook" except ImportError: - return "script" + return 'script' + ipython = get_ipython() + if ipython is None: + return 'script' + # ZMQInteractiveShell: Jupyter; Shell: Colab; TerminalInteractiveShell: ipython REPL. 
+ return 'notebook' if ipython.__class__.__name__ in {'ZMQInteractiveShell', 'Shell'} else 'script' class PlotLosses: - """ - Class collect metrics from the training engine and send it to plugins, when send is called - """ + """Collect training metrics and dispatch them to one or more output plugins.""" + def __init__( self, - outputs: List[Union[Type[BO], str]] = ['MatplotlibPlot', 'ExtremaPrinter'], - mode: Optional[Literal['notebook', 'script']] = None, - figsize: Optional[Tuple[int, int]] = None, - **kwargs + outputs: list[type[BO] | str] | None = None, + mode: OutputMode | None = None, + figsize: tuple[int, int] | None = None, + **kwargs, ): """ Args: - outputs: list of output modules: objects inheriting from BaseOutput - or strings for livelossplot built-in output methods with default parameters - mode: Options: 'notebook' or 'script' - some of outputs need to change some behaviors, - depending on the working environment - figsize: tuple of (width, height) in inches for the figure - **kwargs: key-arguments which are passed to MainLogger constructor + outputs: list of output instances (subclasses of `BaseOutput`) or string names of built-in + outputs to instantiate with defaults. Defaults to `['MatplotlibPlot', 'ExtremaPrinter']`. + mode: 'notebook' or 'script'. Auto-detected if `None`. + figsize: `(width, height)` in inches; applies to `MatplotlibPlot` outputs. + **kwargs: forwarded to `MainLogger`. 
""" + if outputs is None: + outputs = ['MatplotlibPlot', 'ExtremaPrinter'] self.logger = MainLogger(**kwargs) self.outputs = [getattr(livelossplot.outputs, out)() if isinstance(out, str) else out for out in outputs] if mode is None: @@ -54,88 +49,43 @@ def __init__( for out in self.outputs: out.set_output_mode(mode) if figsize is not None and isinstance(out, MatplotlibPlot): - print(f"Setting figsize to {figsize}") out.figsize = figsize - def update(self, *args, **kwargs): - """update logs with arguments that will be passed to main logger""" + def update(self, *args, **kwargs) -> None: + """Forward to `MainLogger.update`.""" self.logger.update(*args, **kwargs) - def send(self): - """Method will send logs to every output class""" + def send(self) -> None: + """Send the current logs to every output plugin.""" for output in self.outputs: output.send(self.logger) - def draw(self): - """Send method substitute from old livelossplot api""" - warnings.warn('draw will be deprecated, please use send method', PendingDeprecationWarning) - self.send() - def reset_outputs(self) -> 'PlotLosses': - """Resets all outputs. - - Returns: - Plotlosses object (so it works for chaining) - """ + """Drop all outputs (chainable).""" self.outputs = [] return self def to_matplotlib(self, **kwargs) -> 'PlotLosses': - """Appends outputs.MatplotlibPlot output, with specified parameters. - - Args: - **kwargs: keyword arguments for MatplotlibPlot - - Returns: - Plotlosses object (so it works for chaining) - """ + """Append a `MatplotlibPlot` output (chainable).""" self.outputs.append(outputs.MatplotlibPlot(**kwargs)) return self def to_extrema_printer(self, **kwargs) -> 'PlotLosses': - """Appends outputs.ExtremaPrinter output, with specified parameters. 
- - Args: - **kwargs: keyword arguments for ExtremaPrinter - - Returns: - Plotlosses object (so it works for chaining) - """ + """Append an `ExtremaPrinter` output (chainable).""" self.outputs.append(outputs.ExtremaPrinter(**kwargs)) return self def to_bokeh(self, **kwargs) -> 'PlotLosses': - """Appends outputs.BokehPlot output, with specified parameters. - - Args: - **kwargs: keyword arguments for BokehPlot - - Returns: - Plotlosses object (so it works for chaining) - """ + """Append a `BokehPlot` output (chainable).""" self.outputs.append(outputs.BokehPlot(**kwargs)) return self def to_tensorboard(self, **kwargs) -> 'PlotLosses': - """Appends outputs.TensorboardLogger output, with specified parameters. - - Args: - **kwargs: keyword arguments for TensorboardLogger - - Returns: - Plotlosses object (so it works for chaining) - """ + """Append a `TensorboardLogger` output (chainable).""" self.outputs.append(outputs.TensorboardLogger(**kwargs)) return self def to_tensorboard_tf(self, **kwargs) -> 'PlotLosses': - """Appends outputs.TensorboardTFLogger output, with specified parameters. - - Args: - **kwargs: keyword arguments for TensorboardTFLogger - - Returns: - Plotlosses object (so it works for chaining) - """ + """Append a `TensorboardTFLogger` output (chainable).""" self.outputs.append(outputs.TensorboardTFLogger(**kwargs)) return self diff --git a/pyproject.toml b/pyproject.toml index 45055bb..e631b02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "livelossplot" -version = "0.6.0" +version = "0.6.1" description = "Live training loss plot in Jupyter Notebook for Keras, PyTorch and others." 
readme = "README.md" license = "MIT" @@ -30,8 +30,9 @@ classifiers = [ "Programming Language :: Python :: 3.13", ] dependencies = [ - "matplotlib", - "bokeh", + "matplotlib>=3.6", + "bokeh>=3.0", + "ipython>=8.0", ] [project.urls] @@ -68,11 +69,12 @@ target-version = "py310" extend-exclude = ["examples", "build", "dist", "scripts"] [tool.ruff.lint] -# Conservative selection — no auto-rewrites of import order or syntax, -# so existing files stay intact. Tighten in a follow-up PR. -select = ["E", "F", "W"] +select = ["E", "F", "W", "I", "UP", "B", "SIM", "RET"] ignore = [ - "E501", # line length is enforced by the formatter; soft-ignore here + "E501", + "B008", + "SIM102", + "SIM108", ] [tool.ruff.lint.per-file-ignores] diff --git a/tests/external_test_examples.py b/tests/external_test_examples.py index dacd9e2..9d2481c 100644 --- a/tests/external_test_examples.py +++ b/tests/external_test_examples.py @@ -21,9 +21,9 @@ def run_notebook(notebook_path): proc.allow_errors = True proc.preprocess(nb, {'metadata': {'path': 'examples/'}}) - output_path = os.path.join(dirname, '_test_{}.ipynb'.format(nb_name)) + output_path = os.path.join(dirname, f'_test_{nb_name}.ipynb') - with open(output_path, mode='wt') as f: + with open(output_path, mode='w') as f: nbformat.write(nb, f) errors = [] for cell in nb.cells: diff --git a/tests/external_test_keras.py b/tests/external_test_keras.py index 3c051cd..6eafb56 100644 --- a/tests/external_test_keras.py +++ b/tests/external_test_keras.py @@ -2,8 +2,7 @@ from keras import Sequential from keras.layers import LSTM, Dense -from numpy import argmax -from numpy import array +from numpy import argmax, array from livelossplot import MainLogger, PlotLossesKeras from livelossplot.outputs import BaseOutput diff --git a/tests/external_test_poutyne.py b/tests/external_test_poutyne.py index 0ff57b1..e4110dd 100644 --- a/tests/external_test_poutyne.py +++ b/tests/external_test_poutyne.py @@ -1,7 +1,7 @@ import torch -from torch import nn, optim 
-from torch.utils.data import TensorDataset, DataLoader from poutyne import Model +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset from livelossplot import MainLogger, PlotLossesPoutyne from livelossplot.outputs import BaseOutput @@ -24,8 +24,7 @@ def get_random_data(): inputs = torch.rand(dataset_size, num_inputs) labels = torch.randint(num_outputs, (dataset_size, )) dataset = TensorDataset(inputs, labels) - dataloader = DataLoader(dataset, batch_size=10) - return dataloader + return DataLoader(dataset, batch_size=10) def test_poutyne(): diff --git a/tests/external_test_pytorch_ignite.py b/tests/external_test_pytorch_ignite.py index 96cb930..08b7d65 100644 --- a/tests/external_test_pytorch_ignite.py +++ b/tests/external_test_pytorch_ignite.py @@ -1,7 +1,7 @@ import torch from ignite import engine from torch import nn, optim -from torch.utils.data import TensorDataset, DataLoader +from torch.utils.data import DataLoader, TensorDataset from livelossplot import MainLogger, PlotLossesIgnite from livelossplot.outputs import BaseOutput @@ -24,8 +24,7 @@ def get_random_data(): inputs = torch.rand(dataset_size, num_inputs) labels = torch.randint(num_outputs, (dataset_size, )) dataset = TensorDataset(inputs, labels) - dataloader = DataLoader(dataset, batch_size=10) - return dataloader + return DataLoader(dataset, batch_size=10) def test_ignite(): diff --git a/tests/test_main_logger.py b/tests/test_main_logger.py index 8e8e8a6..4f67b11 100644 --- a/tests/test_main_logger.py +++ b/tests/test_main_logger.py @@ -6,7 +6,7 @@ def test_main_logger(): logger = MainLogger() logs = {'loss': 0.6} logger.update(logs) - assert 'loss' in logger.log_history.keys() + assert 'loss' in logger.log_history assert len(logger.log_history['loss']) == 1 @@ -72,7 +72,7 @@ def test_main_logger_autogroups(): grouped_log_history = logger.grouped_log_history() target_groups = {'Accuracy': ('validation', 'training'), 'Loss': ('validation', 'training'), 'lr': ('lr', 
)} for target_group, target_metrics in target_groups.items(): - for m1, m2 in zip(sorted(grouped_log_history[target_group].keys()), sorted(target_metrics)): + for m1, m2 in zip(sorted(grouped_log_history[target_group].keys()), sorted(target_metrics), strict=False): assert m1 == m2 diff --git a/uv.lock b/uv.lock index 88b3167..08a61fe 100644 --- a/uv.lock +++ b/uv.lock @@ -1295,10 +1295,12 @@ wheels = [ [[package]] name = "livelossplot" -version = "0.6.0" +version = "0.6.1" source = { editable = "." } dependencies = [ { name = "bokeh" }, + { name = "ipython", version = "8.39.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "ipython", version = "9.13.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "matplotlib" }, ] @@ -1315,8 +1317,9 @@ docs = [ [package.metadata] requires-dist = [ - { name = "bokeh" }, - { name = "matplotlib" }, + { name = "bokeh", specifier = ">=3.0" }, + { name = "ipython", specifier = ">=8.0" }, + { name = "matplotlib", specifier = ">=3.6" }, ] [package.metadata.requires-dev]