examples.example_01.training_logging

Training-time logging helpers for Example 01.

  1"""Training-time logging helpers for Example 01."""
  2
  3from __future__ import annotations
  4
  5import json
  6import math
  7import os
  8from dataclasses import dataclass, field
  9from html import escape
 10from pathlib import Path
 11from typing import Any
 12
 13import numpy as np
 14from PIL import Image, ImageDraw, ImageFont
 15
 16from oak.types import EpisodeTrace
 17
# Defaults used by ``animation_recorder_from_env`` when the matching
# ``OAK_EXAMPLE_ANIMATION_*`` environment variable is unset or empty.
_DEFAULT_ANIMATION_OUTPUT_DIR = Path("tests/results/animations")
_DEFAULT_ANIMATION_EVERY = 100  # capture every N-th episode
_DEFAULT_ANIMATION_LAST = 1  # also capture the final N episodes
_DEFAULT_ANIMATION_FPS = 30
_DEFAULT_ANIMATION_MAX_FRAMES = 400  # frames kept per GIF
# Defaults used by ``curve_recorder_from_env``.
_DEFAULT_CURVE_OUTPUT_DIR = Path("tests/results/training_curves")
_DEFAULT_CURVE_CHECKPOINT_EVERY = 10  # episodes between history checkpoints
# Line styling: opacity applied when a series tuple does not specify one,
# plus the opacity/width pairs used for raw and moving-average series.
_DEFAULT_SERIES_OPACITY = 1.0
_RAW_SERIES_OPACITY = 0.25
_RAW_SERIES_WIDTH = 1.5
_AVG_SERIES_WIDTH = 2.75
 29
 30
 31def _moving_average(values: list[float], window: int) -> list[float]:
 32    if not values:
 33        return []
 34    window = max(window, 1)
 35    averaged: list[float] = []
 36    for index, value in enumerate(values):
 37        del value
 38        start = max(0, index - window + 1)
 39        window_values = [
 40            float(candidate)
 41            for candidate in values[start : index + 1]
 42            if math.isfinite(float(candidate))
 43        ]
 44        if not window_values:
 45            averaged.append(float("nan"))
 46        else:
 47            averaged.append(sum(window_values) / len(window_values))
 48    return averaged
 49
 50
 51def _episode_metric_snapshot(agent: Any) -> dict[str, float]:
 52    metrics: dict[str, float] = {}
 53    for component_name in (
 54        "perception",
 55        "value_function",
 56        "reactive_policy",
 57        "transition_model",
 58    ):
 59        component = getattr(agent, component_name, None)
 60        metric_fn = getattr(component, "training_metrics", None)
 61        if callable(metric_fn):
 62            for key, value in dict(metric_fn()).items():
 63                if math.isfinite(float(value)):
 64                    metrics[key] = float(value)
 65    return metrics
 66
 67
 68@dataclass(slots=True, frozen=True)
 69class EpisodeCaptureSchedule:
 70    """Select which training episodes should produce artifacts."""
 71
 72    episode_indices: tuple[int, ...] = ()
 73    every_n_episodes: int | None = None
 74    last_n_episodes: int = 0
 75
 76    def should_capture(self, episode: int, total_episodes: int) -> bool:
 77        if episode in self.episode_indices:
 78            return True
 79        if self.every_n_episodes is not None and self.every_n_episodes > 0:
 80            if episode % self.every_n_episodes == 0:
 81                return True
 82        if self.last_n_episodes > 0 and episode >= max(total_episodes - self.last_n_episodes, 0):
 83            return True
 84        return False
 85
 86
 87@dataclass(slots=True)
 88class EpisodeAnimationRecorder:
 89    """Save traced episode frames as GIF animations plus compact metadata."""
 90
 91    output_dir: Path
 92    schedule: EpisodeCaptureSchedule = field(default_factory=EpisodeCaptureSchedule)
 93    prefix: str = "training"
 94    fps: int = 30
 95    max_frames: int | None = None
 96    save_metadata: bool = True
 97
 98    def should_capture(self, episode: int, total_episodes: int) -> bool:
 99        return self.schedule.should_capture(episode, total_episodes)
100
101    def __call__(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> None:
102        self.output_dir.mkdir(parents=True, exist_ok=True)
103        stem = self._episode_stem(trace)
104        metadata_path = self.output_dir / f"{stem}.json"
105        gif_path = self.output_dir / f"{stem}.gif"
106
107        frames = list(trace.frames)
108        if self.max_frames is not None and self.max_frames > 0:
109            frames = frames[: self.max_frames]
110        if not frames:
111            if self.save_metadata:
112                metadata = self._trace_metadata(trace)
113                metadata["warning"] = (
114                    "No frames were captured for this episode. Ensure the world was "
115                    "created with render_mode='rgb_array' and the environment's "
116                    "rendering dependencies are installed."
117                )
118                metadata_path.write_text(json.dumps(metadata, indent=2) + "\n")
119            return
120
121        pil_frames = [self._to_pil_image(frame) for frame in frames]
122        pil_frames[0].save(
123            gif_path,
124            save_all=True,
125            append_images=pil_frames[1:],
126            duration=max(int(1000 / max(self.fps, 1)), 1),
127            loop=0,
128        )
129
130        if self.save_metadata:
131            metadata_path.write_text(json.dumps(self._trace_metadata(trace), indent=2) + "\n")
132
133    def _episode_stem(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> str:
134        return (
135            f"{self.prefix}_episode_{trace.episode:04d}"
136            f"_reward_{int(round(trace.episode_reward))}"
137            f"_avg_{int(round(trace.avg_reward))}"
138        )
139
140    def _trace_metadata(
141        self,
142        trace: EpisodeTrace[Any, Any, Any, Any],
143    ) -> dict[str, Any]:
144        return {
145            "episode": trace.episode,
146            "episode_reward": trace.episode_reward,
147            "avg_reward": trace.avg_reward,
148            "step_count": trace.step_count,
149            "solved": trace.solved,
150            "frame_count": len(trace.frames),
151            "actions": [step.action for step in trace.steps],
152            "active_option_ids": [step.active_option_id for step in trace.steps],
153            "metadata": dict(trace.metadata),
154        }
155
156    def _to_pil_image(self, frame: object) -> Image.Image:
157        if isinstance(frame, Image.Image):
158            return frame.convert("RGB")
159
160        array = np.asarray(frame)
161        if array.ndim == 2:
162            array = np.stack([array] * 3, axis=-1)
163        if array.dtype != np.uint8:
164            array = np.clip(array, 0, 255).astype(np.uint8)
165        return Image.fromarray(array).convert("RGB")
166
167
def _animation_env_int(name: str, default: int) -> int:
    """Read integer env var ``name``; unset or blank values yield ``default``."""
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None


def animation_recorder_from_env(mode: str) -> EpisodeAnimationRecorder | None:
    """Build an animation recorder from env vars with sensible defaults.

    Animation capture is enabled by default for the example training entry
    points so that long runs automatically leave behind visual artifacts.
    Set ``OAK_EXAMPLE_DISABLE_ANIMATIONS=1`` to turn it off.

    Environment variables (each falls back to its module default when unset
    or blank):

    * ``OAK_EXAMPLE_ANIMATION_DIR``: base output directory; ``mode`` is
      appended as a sub-directory.
    * ``OAK_EXAMPLE_ANIMATION_EPISODES``: comma-separated episode indices
      that are always captured.
    * ``OAK_EXAMPLE_ANIMATION_EVERY``: capture every N-th episode; ``0``
      turns periodic capture off.
    * ``OAK_EXAMPLE_ANIMATION_LAST``: capture the final N episodes.
    * ``OAK_EXAMPLE_ANIMATION_FPS``: GIF frame rate.
    * ``OAK_EXAMPLE_ANIMATION_MAX_FRAMES``: frame cap per GIF; ``0`` removes
      the cap.

    Args:
        mode: Name used both as output sub-directory and as file prefix.

    Returns:
        The configured recorder, or ``None`` when animations are disabled.

    Raises:
        ValueError: If an integer variable holds a non-integer value; the
            message names the offending variable.
    """

    disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_ANIMATIONS", "").strip().lower()
    if disable_token in {"1", "true", "yes", "on"}:
        return None

    output_dir = (
        os.environ.get("OAK_EXAMPLE_ANIMATION_DIR", "").strip()
        or str(_DEFAULT_ANIMATION_OUTPUT_DIR)
    )

    episode_indices: list[int] = []
    for token in os.environ.get("OAK_EXAMPLE_ANIMATION_EPISODES", "").split(","):
        token = token.strip()
        if not token:
            # Tolerate empty items such as a trailing comma.
            continue
        try:
            episode_indices.append(int(token))
        except ValueError:
            raise ValueError(
                "OAK_EXAMPLE_ANIMATION_EPISODES must be a comma-separated list "
                f"of integers, got {token!r}"
            ) from None

    every_n = _animation_env_int("OAK_EXAMPLE_ANIMATION_EVERY", _DEFAULT_ANIMATION_EVERY)
    last_n = _animation_env_int("OAK_EXAMPLE_ANIMATION_LAST", _DEFAULT_ANIMATION_LAST)
    fps = _animation_env_int("OAK_EXAMPLE_ANIMATION_FPS", _DEFAULT_ANIMATION_FPS)
    max_frames = _animation_env_int(
        "OAK_EXAMPLE_ANIMATION_MAX_FRAMES",
        _DEFAULT_ANIMATION_MAX_FRAMES,
    )

    return EpisodeAnimationRecorder(
        output_dir=Path(output_dir) / mode,
        schedule=EpisodeCaptureSchedule(
            episode_indices=tuple(episode_indices),
            # An explicit 0 means "no periodic capture" / "no frame cap".
            every_n_episodes=every_n or None,
            last_n_episodes=last_n,
        ),
        prefix=mode,
        fps=fps,
        max_frames=max_frames or None,
    )
230
231
def curve_recorder_from_env(
    mode: str,
    *,
    average_window: int,
) -> TrainingCurveRecorder | None:
    """Build a curve recorder from env vars with sensible defaults.

    Curve capture is enabled by default so short training runs still write the
    SVG/JSON artifacts used in the report. Set
    ``OAK_EXAMPLE_DISABLE_CURVES=1`` to turn it off.

    The output directory is taken from ``OAK_EXAMPLE_CURVE_DIR``, then
    ``OAK_EXAMPLE_PLOT_DIR``, then the module default; ``mode`` is appended as
    a sub-directory. ``OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY`` sets the checkpoint
    interval and falls back to the module default when unset or blank.

    Args:
        mode: Name used both as output sub-directory and as file prefix.
        average_window: Moving-average window passed to the recorder.

    Returns:
        The configured recorder, or ``None`` when curves are disabled.

    Raises:
        ValueError: If ``OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY`` is not an integer.
    """

    disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_CURVES", "").strip().lower()
    if disable_token in {"1", "true", "yes", "on"}:
        return None

    output_dir = (
        os.environ.get("OAK_EXAMPLE_CURVE_DIR", "").strip()
        or os.environ.get("OAK_EXAMPLE_PLOT_DIR", "").strip()
        or str(_DEFAULT_CURVE_OUTPUT_DIR)
    )

    # Strip first so a whitespace-only value falls back to the default, the
    # same way the directory variables above behave.
    raw_every = os.environ.get("OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY", "").strip()
    if raw_every:
        try:
            checkpoint_every = int(raw_every)
        except ValueError:
            raise ValueError(
                "OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY must be an integer, "
                f"got {raw_every!r}"
            ) from None
    else:
        checkpoint_every = _DEFAULT_CURVE_CHECKPOINT_EVERY

    return TrainingCurveRecorder(
        output_dir=Path(output_dir) / mode,
        prefix=mode,
        average_window=average_window,
        checkpoint_every=checkpoint_every,
    )
266
267
268def _write_line_plot_svg(
269    path: Path,
270    *,
271    title: str,
272    x_label: str,
273    y_label: str,
274    series: list[tuple[Any, ...]],
275) -> None:
276    """Write a lightweight SVG line chart without external plotting deps."""
277    width = 960
278    height = 540
279    left = 70
280    right = 24
281    top = 52
282    bottom = 58
283    plot_width = width - left - right
284    plot_height = height - top - bottom
285
286    values = [
287        float(value)
288        for entry in series
289        for value in entry[1]
290        if math.isfinite(float(value))
291    ]
292    if not values:
293        values = [0.0, 1.0]
294
295    y_min = min(values)
296    y_max = max(values)
297    if math.isclose(y_min, y_max):
298        pad = max(1.0, abs(y_min) * 0.1 + 1.0)
299        y_min -= pad
300        y_max += pad
301
302    x_max = max((len(entry[1]) - 1 for entry in series if entry[1]), default=1)
303    x_max = max(x_max, 1)
304
305    def x_pos(index: int) -> float:
306        return left + (index / x_max) * plot_width
307
308    def y_pos(value: float) -> float:
309        norm = (value - y_min) / (y_max - y_min)
310        return top + (1.0 - norm) * plot_height
311
312    grid_lines: list[str] = []
313    tick_labels: list[str] = []
314    for tick in range(6):
315        frac = tick / 5
316        y_value = y_max - frac * (y_max - y_min)
317        y = top + frac * plot_height
318        grid_lines.append(
319            f'<line x1="{left}" y1="{y:.2f}" x2="{width - right}" y2="{y:.2f}" '
320            'stroke="#d7dde5" stroke-width="1" />'
321        )
322        tick_labels.append(
323            f'<text x="{left - 10}" y="{y + 4:.2f}" text-anchor="end" '
324            'font-size="12" fill="#334155">'
325            f"{y_value:.2f}</text>"
326        )
327
328    x_ticks = sorted({round(tick * x_max / 5) for tick in range(6)})
329    x_tick_labels: list[str] = []
330    for x_tick in x_ticks:
331        x = x_pos(int(x_tick))
332        grid_lines.append(
333            f'<line x1="{x:.2f}" y1="{top}" x2="{x:.2f}" y2="{top + plot_height:.2f}" '
334            'stroke="#eef2f7" stroke-width="1" />'
335        )
336        x_tick_labels.append(
337            f'<text x="{x:.2f}" y="{top + plot_height + 20:.2f}" text-anchor="middle" '
338            'font-size="12" fill="#334155">'
339            f"{int(x_tick)}</text>"
340        )
341
342    polylines: list[str] = []
343    legend_items: list[str] = []
344    for index, entry in enumerate(series):
345        label = str(entry[0])
346        seq = entry[1]
347        color = str(entry[2])
348        opacity = float(entry[3]) if len(entry) >= 4 else _DEFAULT_SERIES_OPACITY
349        stroke_width = float(entry[4]) if len(entry) >= 5 else 2.5
350        points = " ".join(
351            f"{x_pos(i):.2f},{y_pos(float(value)):.2f}"
352            for i, value in enumerate(seq)
353            if math.isfinite(float(value))
354        )
355        if points:
356            polylines.append(
357                f'<polyline fill="none" stroke="{color}" stroke-width="{stroke_width}" '
358                f'stroke-opacity="{opacity:.3f}" '
359                f'stroke-linejoin="round" stroke-linecap="round" points="{points}" />'
360            )
361        legend_y = top + 16 + 20 * index
362        legend_items.append(
363            f'<rect x="{width - right - 180}" y="{legend_y - 10}" width="14" height="14" '
364            f'fill="{color}" fill-opacity="{max(opacity, 0.35):.3f}" rx="3" />'
365            f'<text x="{width - right - 160}" y="{legend_y + 1}" font-size="13" '
366            'fill="#0f172a">'
367            f"{escape(label)}</text>"
368        )
369
370    svg = f"""<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}" role="img" aria-label="{escape(title)}">
371  <rect width="100%" height="100%" fill="#f8fafc" />
372  <text x="{left}" y="28" font-size="22" font-weight="700" fill="#0f172a">{escape(title)}</text>
373  <rect x="{left}" y="{top}" width="{plot_width}" height="{plot_height}" fill="#ffffff" stroke="#cbd5e1" stroke-width="1.5" rx="10" />
374  {''.join(grid_lines)}
375  {''.join(tick_labels)}
376  {''.join(x_tick_labels)}
377  {''.join(polylines)}
378  {''.join(legend_items)}
379  <text x="{left + plot_width / 2:.2f}" y="{height - 16}" text-anchor="middle" font-size="14" fill="#334155">{escape(x_label)}</text>
380  <text x="20" y="{top + plot_height / 2:.2f}" text-anchor="middle" font-size="14" fill="#334155" transform="rotate(-90 20 {top + plot_height / 2:.2f})">{escape(y_label)}</text>
381</svg>
382"""
383    path.parent.mkdir(parents=True, exist_ok=True)
384    path.write_text(svg)
385
386
387def _hex_to_rgba(color: str, alpha: float = 1.0) -> tuple[int, int, int, int]:
388    color = color.lstrip("#")
389    if len(color) != 6:
390        return (107, 114, 128, int(255 * alpha))
391    red = int(color[0:2], 16)
392    green = int(color[2:4], 16)
393    blue = int(color[4:6], 16)
394    return (red, green, blue, max(0, min(255, int(round(255 * alpha)))))
395
396
def _write_line_plot_png(
    path: Path,
    *,
    title: str,
    x_label: str,
    y_label: str,
    series: list[tuple[Any, ...]],
) -> None:
    """Render the line chart as a PNG using only Pillow drawing primitives.

    ``series`` holds tuples of ``(label, values, color[, opacity[,
    stroke_width]])``. Non-finite values are skipped, and a series needs at
    least two finite points to be drawn. Parent directories of ``path`` are
    created as needed.
    """
    canvas_w, canvas_h = 960, 540
    margin_left, margin_right, margin_top, margin_bottom = 70, 24, 52, 58
    plot_w = canvas_w - margin_left - margin_right
    plot_h = canvas_h - margin_top - margin_bottom
    plot_bottom = margin_top + plot_h

    # Vertical range: every finite value across all series.
    finite_values: list[float] = []
    for entry in series:
        for raw in entry[1]:
            number = float(raw)
            if math.isfinite(number):
                finite_values.append(number)
    if not finite_values:
        finite_values = [0.0, 1.0]

    low, high = min(finite_values), max(finite_values)
    if math.isclose(low, high):
        # Widen a flat range so the y mapping never divides by zero.
        spread = max(1.0, abs(low) * 0.1 + 1.0)
        low -= spread
        high += spread

    last_index = max(
        max((len(entry[1]) - 1 for entry in series if entry[1]), default=1),
        1,
    )

    def to_x(index: int) -> float:
        return margin_left + (index / last_index) * plot_w

    def to_y(value: float) -> float:
        norm = (value - low) / (high - low)
        return margin_top + (1.0 - norm) * plot_h

    def opacity_of(entry: tuple[Any, ...]) -> float:
        return float(entry[3]) if len(entry) >= 4 else _DEFAULT_SERIES_OPACITY

    title_ink = (15, 23, 42, 255)
    label_ink = (51, 65, 85, 255)

    image = Image.new("RGBA", (canvas_w, canvas_h), (248, 250, 252, 255))
    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()

    # Plot frame and title.
    draw.rounded_rectangle(
        (margin_left, margin_top, margin_left + plot_w, plot_bottom),
        radius=10,
        fill=(255, 255, 255, 255),
        outline=(203, 213, 225, 255),
        width=2,
    )
    draw.text((margin_left, 16), title, fill=title_ink, font=font)

    # Six horizontal grid lines, labelled from the top of the range down.
    for step in range(6):
        frac = step / 5
        y_value = high - frac * (high - low)
        y = margin_top + frac * plot_h
        draw.line(
            ((margin_left, y), (canvas_w - margin_right, y)),
            fill=(215, 221, 229, 255),
            width=1,
        )
        draw.text(
            (margin_left - 52, y - 6),
            f"{y_value:.2f}",
            fill=label_ink,
            font=font,
        )

    # Vertical grid lines; the set drops duplicate ticks on short series.
    for tick_index in sorted({round(step * last_index / 5) for step in range(6)}):
        x = to_x(int(tick_index))
        draw.line(
            ((x, margin_top), (x, plot_bottom)),
            fill=(238, 242, 247, 255),
            width=1,
        )
        draw.text(
            (x - 8, plot_bottom + 8),
            str(int(tick_index)),
            fill=label_ink,
            font=font,
        )

    # Data lines first; the legend is drawn afterwards so it stays on top.
    for entry in series:
        color = str(entry[2])
        opacity = opacity_of(entry)
        raw_width = float(entry[4]) if len(entry) >= 5 else 2.5
        pen_width = max(1, int(round(raw_width)))
        vertices = []
        for position, raw in enumerate(entry[1]):
            number = float(raw)
            if math.isfinite(number):
                vertices.append((to_x(position), to_y(number)))
        if len(vertices) < 2:
            continue
        draw.line(
            vertices,
            fill=_hex_to_rgba(color, opacity),
            width=pen_width,
            joint="curve",
        )

    legend_x = canvas_w - margin_right - 180
    for row, entry in enumerate(series):
        label = str(entry[0])
        color = str(entry[2])
        swatch_top = margin_top + 6 + 20 * row
        # Swatches keep a minimum opacity so faint series remain visible.
        draw.rounded_rectangle(
            (legend_x, swatch_top, legend_x + 14, swatch_top + 14),
            radius=3,
            fill=_hex_to_rgba(color, max(opacity_of(entry), 0.35)),
        )
        draw.text(
            (legend_x + 20, swatch_top),
            label,
            fill=title_ink,
            font=font,
        )

    # Axis captions (the y caption is drawn unrotated).
    draw.text(
        (margin_left + plot_w / 2 - 25, canvas_h - 24),
        x_label,
        fill=label_ink,
        font=font,
    )
    draw.text((8, margin_top + plot_h / 2), y_label, fill=label_ink, font=font)

    path.parent.mkdir(parents=True, exist_ok=True)
    image.convert("RGB").save(path, format="PNG")
531
532
def _write_line_plot_assets(
    stem: Path,
    *,
    title: str,
    x_label: str,
    y_label: str,
    series: list[tuple[Any, ...]],
) -> None:
    """Write the same chart as ``<stem>.svg`` and ``<stem>.png``.

    Args:
        stem: Output path without extension.
        title: Chart title.
        x_label: Caption for the x axis.
        y_label: Caption for the y axis.
        series: Series tuples, passed unchanged to both writers.
    """
    # Append the extension instead of calling Path.with_suffix(): with_suffix
    # replaces everything after the last dot, so a stem such as
    # "run.v2_reward_curve" would collapse to "run.svg" and distinct plots
    # would overwrite each other.
    svg_path = stem.with_name(f"{stem.name}.svg")
    png_path = stem.with_name(f"{stem.name}.png")
    _write_line_plot_svg(
        svg_path,
        title=title,
        x_label=x_label,
        y_label=y_label,
        series=series,
    )
    _write_line_plot_png(
        png_path,
        title=title,
        x_label=x_label,
        y_label=y_label,
        series=series,
    )
555
556
557@dataclass(slots=True)
558class TrainingCurveRecorder:
559    """Persist reward and moving-average curves for a full training run."""
560
561    output_dir: Path
562    prefix: str = "training"
563    average_window: int = 100
564    reward_color: str = "#264653"
565    average_color: str = "#2a9d8f"
566    epsilon_color: str = "#e76f51"
567    option_color: str = "#457b9d"
568    reward_history: list[float] = field(default_factory=list)
569    average_history: list[float] = field(default_factory=list)
570    epsilon_history: list[float] = field(default_factory=list)
571    option_history: list[float] = field(default_factory=list)
572    metric_histories: dict[str, list[float]] = field(default_factory=dict)
573    checkpoint_every: int = 0
574    metric_colors: dict[str, str] = field(
575        default_factory=lambda: {
576            "perception_encoder_loss": "#8a5cf6",
577            "value_q_omega_loss": "#c0392b",
578            "value_gvf_loss": "#d35400",
579            "policy_q_loss": "#2980b9",
580            "policy_termination_loss": "#16a085",
581            "model_loss": "#7f8c8d",
582            "model_done_loss": "#34495e",
583        }
584    )
585
586    def log_episode(
587        self,
588        episode: int,
589        reward: float,
590        avg_reward: float,
591        agent: Any,
592    ) -> None:
593        """Record one episode's aggregate metrics."""
594        current_episode = int(episode)
595        self.reward_history.append(float(reward))
596        self.average_history.append(float(avg_reward))
597        self.epsilon_history.append(
598            float(getattr(agent.reactive_policy, "epsilon", 0.0))
599        )
600        self.option_history.append(
601            float(len(getattr(agent.reactive_policy, "_options", {})))
602        )
603        current_metrics = _episode_metric_snapshot(agent)
604        episode_count = len(self.reward_history)
605        for metric_name in current_metrics:
606            self.metric_histories.setdefault(
607                metric_name,
608                [float("nan")] * (episode_count - 1),
609            )
610        for metric_name, history in self.metric_histories.items():
611            history.append(float(current_metrics.get(metric_name, float("nan"))))
612        if self.checkpoint_every > 0 and (current_episode + 1) % self.checkpoint_every == 0:
613            self.save_checkpoint()
614
615    def _history_payload(self) -> dict[str, Any]:
616        return {
617            "average_window": self.average_window,
618            "episodes": len(self.reward_history),
619            "reward_history": self.reward_history,
620            "average_history": self.average_history,
621            "epsilon_history": self.epsilon_history,
622            "option_history": self.option_history,
623            "metric_histories": self.metric_histories,
624            "best_reward": max(self.reward_history),
625            "final_reward": self.reward_history[-1],
626            "final_average": self.average_history[-1],
627        }
628
629    def save_checkpoint(self) -> None:
630        """Write the raw training histories without rendering plots."""
631        if not self.reward_history:
632            return
633        self.output_dir.mkdir(parents=True, exist_ok=True)
634        stem = self.output_dir / self.prefix
635        payload = self._history_payload()
636        payload["checkpoint"] = True
637        stem.with_name(f"{self.prefix}_reward_history_checkpoint.json").write_text(
638            json.dumps(payload, indent=2) + "\n"
639        )
640
641    def save(self) -> None:
642        """Write SVG plots and the raw history JSON."""
643        if not self.reward_history:
644            return
645
646        self.output_dir.mkdir(parents=True, exist_ok=True)
647        stem = self.output_dir / self.prefix
648
649        _write_line_plot_assets(
650            stem.with_name(f"{self.prefix}_reward_curve"),
651            title=f"{self.prefix} reward over episodes",
652            x_label="Episode",
653            y_label="Reward",
654            series=[
655                (
656                    "reward",
657                    self.reward_history,
658                    self.reward_color,
659                    _RAW_SERIES_OPACITY,
660                    _RAW_SERIES_WIDTH,
661                ),
662                (
663                    f"avg{self.average_window}",
664                    _moving_average(self.reward_history, self.average_window),
665                    self.average_color,
666                    _DEFAULT_SERIES_OPACITY,
667                    _AVG_SERIES_WIDTH,
668                ),
669            ],
670        )
671        _write_line_plot_assets(
672            stem.with_name(f"{self.prefix}_training_state"),
673            title=f"{self.prefix} exploration and option count",
674            x_label="Episode",
675            y_label="Value",
676            series=[
677                ("epsilon", self.epsilon_history, self.epsilon_color),
678                ("options", self.option_history, self.option_color),
679            ],
680        )
681
682        for metric_name, history in sorted(self.metric_histories.items()):
683            if not history:
684                continue
685            metric_label = metric_name.replace("_", " ")
686            color = self.metric_colors.get(metric_name, "#6b7280")
687            smoothed = _moving_average(history, self.average_window)
688            _write_line_plot_assets(
689                stem.with_name(f"{self.prefix}_{metric_name}"),
690                title=f"{self.prefix} {metric_label} over episodes",
691                x_label="Episode",
692                y_label=metric_label,
693                series=[
694                    (
695                        f"{metric_label} raw",
696                        history,
697                        color,
698                        _RAW_SERIES_OPACITY,
699                        _RAW_SERIES_WIDTH,
700                    ),
701                    (
702                        f"{metric_label} avg{self.average_window}",
703                        smoothed,
704                        color,
705                        _DEFAULT_SERIES_OPACITY,
706                        _AVG_SERIES_WIDTH,
707                    ),
708                ],
709            )
710
711        payload = self._history_payload()
712        stem.with_name(f"{self.prefix}_reward_history.json").write_text(
713            json.dumps(payload, indent=2) + "\n"
714        )
@dataclass(slots=True, frozen=True)
class EpisodeCaptureSchedule:
69@dataclass(slots=True, frozen=True)
70class EpisodeCaptureSchedule:
71    """Select which training episodes should produce artifacts."""
72
73    episode_indices: tuple[int, ...] = ()
74    every_n_episodes: int | None = None
75    last_n_episodes: int = 0
76
77    def should_capture(self, episode: int, total_episodes: int) -> bool:
78        if episode in self.episode_indices:
79            return True
80        if self.every_n_episodes is not None and self.every_n_episodes > 0:
81            if episode % self.every_n_episodes == 0:
82                return True
83        if self.last_n_episodes > 0 and episode >= max(total_episodes - self.last_n_episodes, 0):
84            return True
85        return False

Select which training episodes should produce artifacts.

EpisodeCaptureSchedule( episode_indices: 'tuple[int, ...]' = (), every_n_episodes: 'int | None' = None, last_n_episodes: 'int' = 0)
episode_indices: 'tuple[int, ...]'
every_n_episodes: 'int | None'
last_n_episodes: 'int'
def should_capture(self, episode: 'int', total_episodes: 'int') -> 'bool':
77    def should_capture(self, episode: int, total_episodes: int) -> bool:
78        if episode in self.episode_indices:
79            return True
80        if self.every_n_episodes is not None and self.every_n_episodes > 0:
81            if episode % self.every_n_episodes == 0:
82                return True
83        if self.last_n_episodes > 0 and episode >= max(total_episodes - self.last_n_episodes, 0):
84            return True
85        return False
@dataclass(slots=True)
class EpisodeAnimationRecorder:
 88@dataclass(slots=True)
 89class EpisodeAnimationRecorder:
 90    """Save traced episode frames as GIF animations plus compact metadata."""
 91
 92    output_dir: Path
 93    schedule: EpisodeCaptureSchedule = field(default_factory=EpisodeCaptureSchedule)
 94    prefix: str = "training"
 95    fps: int = 30
 96    max_frames: int | None = None
 97    save_metadata: bool = True
 98
 99    def should_capture(self, episode: int, total_episodes: int) -> bool:
100        return self.schedule.should_capture(episode, total_episodes)
101
102    def __call__(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> None:
103        self.output_dir.mkdir(parents=True, exist_ok=True)
104        stem = self._episode_stem(trace)
105        metadata_path = self.output_dir / f"{stem}.json"
106        gif_path = self.output_dir / f"{stem}.gif"
107
108        frames = list(trace.frames)
109        if self.max_frames is not None and self.max_frames > 0:
110            frames = frames[: self.max_frames]
111        if not frames:
112            if self.save_metadata:
113                metadata = self._trace_metadata(trace)
114                metadata["warning"] = (
115                    "No frames were captured for this episode. Ensure the world was "
116                    "created with render_mode='rgb_array' and the environment's "
117                    "rendering dependencies are installed."
118                )
119                metadata_path.write_text(json.dumps(metadata, indent=2) + "\n")
120            return
121
122        pil_frames = [self._to_pil_image(frame) for frame in frames]
123        pil_frames[0].save(
124            gif_path,
125            save_all=True,
126            append_images=pil_frames[1:],
127            duration=max(int(1000 / max(self.fps, 1)), 1),
128            loop=0,
129        )
130
131        if self.save_metadata:
132            metadata_path.write_text(json.dumps(self._trace_metadata(trace), indent=2) + "\n")
133
134    def _episode_stem(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> str:
135        return (
136            f"{self.prefix}_episode_{trace.episode:04d}"
137            f"_reward_{int(round(trace.episode_reward))}"
138            f"_avg_{int(round(trace.avg_reward))}"
139        )
140
141    def _trace_metadata(
142        self,
143        trace: EpisodeTrace[Any, Any, Any, Any],
144    ) -> dict[str, Any]:
145        return {
146            "episode": trace.episode,
147            "episode_reward": trace.episode_reward,
148            "avg_reward": trace.avg_reward,
149            "step_count": trace.step_count,
150            "solved": trace.solved,
151            "frame_count": len(trace.frames),
152            "actions": [step.action for step in trace.steps],
153            "active_option_ids": [step.active_option_id for step in trace.steps],
154            "metadata": dict(trace.metadata),
155        }
156
157    def _to_pil_image(self, frame: object) -> Image.Image:
158        if isinstance(frame, Image.Image):
159            return frame.convert("RGB")
160
161        array = np.asarray(frame)
162        if array.ndim == 2:
163            array = np.stack([array] * 3, axis=-1)
164        if array.dtype != np.uint8:
165            array = np.clip(array, 0, 255).astype(np.uint8)
166        return Image.fromarray(array).convert("RGB")

Save traced episode frames as GIF animations plus compact metadata.

EpisodeAnimationRecorder( output_dir: 'Path', schedule: 'EpisodeCaptureSchedule' = <factory>, prefix: 'str' = 'training', fps: 'int' = 30, max_frames: 'int | None' = None, save_metadata: 'bool' = True)
output_dir: 'Path'
schedule: 'EpisodeCaptureSchedule'
prefix: 'str'
fps: 'int'
max_frames: 'int | None'
save_metadata: 'bool'
def should_capture(self, episode: 'int', total_episodes: 'int') -> 'bool':
 99    def should_capture(self, episode: int, total_episodes: int) -> bool:
100        return self.schedule.should_capture(episode, total_episodes)
def animation_recorder_from_env(mode: 'str') -> 'EpisodeAnimationRecorder | None':
def animation_recorder_from_env(mode: str) -> EpisodeAnimationRecorder | None:
    """Build an animation recorder from env vars with sensible defaults.

    Animation capture is enabled by default for the example training entry
    points so that long runs automatically leave behind visual artifacts.
    Set ``OAK_EXAMPLE_DISABLE_ANIMATIONS=1`` (or ``true``/``yes``/``on``) to
    turn it off, in which case ``None`` is returned.

    Optional overrides, each falling back to its module default when unset
    or blank:

    * ``OAK_EXAMPLE_ANIMATION_DIR``: root directory; ``mode`` is appended.
    * ``OAK_EXAMPLE_ANIMATION_EPISODES``: comma-separated episode indices.
    * ``OAK_EXAMPLE_ANIMATION_EVERY``: forwarded as ``every_n_episodes``;
      ``0`` is forwarded as ``None``.
    * ``OAK_EXAMPLE_ANIMATION_LAST``: forwarded as ``last_n_episodes``.
    * ``OAK_EXAMPLE_ANIMATION_FPS``: forwarded as ``fps``.
    * ``OAK_EXAMPLE_ANIMATION_MAX_FRAMES``: forwarded as ``max_frames``;
      ``0`` is forwarded as ``None``.

    ``mode`` is also used as the recorder's file prefix. A ``ValueError``
    naming the offending variable is raised when an integer setting cannot
    be parsed.
    """

    def _parse_int(name: str, token: str) -> int:
        # Re-raise with the variable name so a bad setting is easy to locate.
        try:
            return int(token)
        except ValueError:
            raise ValueError(
                f"{name} must be an integer, got {token!r}"
            ) from None

    def _int_setting(name: str, default: int) -> int:
        # Strip before the emptiness check so a whitespace-only value falls
        # back to the default instead of crashing inside ``int``.
        token = os.environ.get(name, "").strip()
        return _parse_int(name, token) if token else default

    disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_ANIMATIONS", "").strip().lower()
    if disable_token in {"1", "true", "yes", "on"}:
        return None

    output_dir = (
        os.environ.get("OAK_EXAMPLE_ANIMATION_DIR", "").strip()
        or str(_DEFAULT_ANIMATION_OUTPUT_DIR)
    )
    episode_tokens = os.environ.get("OAK_EXAMPLE_ANIMATION_EPISODES", "").split(",")
    episode_indices = tuple(
        _parse_int("OAK_EXAMPLE_ANIMATION_EPISODES", token.strip())
        for token in episode_tokens
        if token.strip()
    )
    every_n = _int_setting("OAK_EXAMPLE_ANIMATION_EVERY", _DEFAULT_ANIMATION_EVERY)
    last_n = _int_setting("OAK_EXAMPLE_ANIMATION_LAST", _DEFAULT_ANIMATION_LAST)
    fps = _int_setting("OAK_EXAMPLE_ANIMATION_FPS", _DEFAULT_ANIMATION_FPS)
    max_frames = _int_setting(
        "OAK_EXAMPLE_ANIMATION_MAX_FRAMES",
        _DEFAULT_ANIMATION_MAX_FRAMES,
    )

    return EpisodeAnimationRecorder(
        output_dir=Path(output_dir) / mode,
        schedule=EpisodeCaptureSchedule(
            episode_indices=episode_indices,
            # A zero interval means "no periodic capture".
            every_n_episodes=every_n or None,
            last_n_episodes=last_n,
        ),
        prefix=mode,
        fps=fps,
        # A zero cap means "no frame limit".
        max_frames=max_frames or None,
    )

Build an animation recorder from env vars with sensible defaults.

Animation capture is enabled by default for the example training entry points so that long runs automatically leave behind visual artifacts. Set OAK_EXAMPLE_DISABLE_ANIMATIONS=1 to turn it off.

def curve_recorder_from_env(mode: 'str', *, average_window: 'int') -> 'TrainingCurveRecorder | None':
def curve_recorder_from_env(
    mode: str,
    *,
    average_window: int,
) -> TrainingCurveRecorder | None:
    """Build a curve recorder from env vars with sensible defaults.

    Curve capture is enabled by default so short training runs still write the
    SVG/JSON artifacts used in the report. Set
    ``OAK_EXAMPLE_DISABLE_CURVES=1`` (or ``true``/``yes``/``on``) to turn it
    off, in which case ``None`` is returned.

    The output root is taken from ``OAK_EXAMPLE_CURVE_DIR``, then
    ``OAK_EXAMPLE_PLOT_DIR``, then the module default; ``mode`` is appended to
    it and also used as the file prefix. ``OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY``
    sets the checkpoint interval and falls back to the module default when
    unset or blank. ``average_window`` is forwarded to the recorder as is.

    A ``ValueError`` naming the variable is raised when the checkpoint
    interval cannot be parsed as an integer.
    """

    disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_CURVES", "").strip().lower()
    if disable_token in {"1", "true", "yes", "on"}:
        return None

    output_dir = (
        os.environ.get("OAK_EXAMPLE_CURVE_DIR", "").strip()
        or os.environ.get("OAK_EXAMPLE_PLOT_DIR", "").strip()
        or str(_DEFAULT_CURVE_OUTPUT_DIR)
    )

    checkpoint_var = "OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY"
    # Strip before the emptiness check so a whitespace-only value falls back
    # to the default instead of crashing inside ``int``.
    checkpoint_token = os.environ.get(checkpoint_var, "").strip()
    if not checkpoint_token:
        checkpoint_every = _DEFAULT_CURVE_CHECKPOINT_EVERY
    else:
        try:
            checkpoint_every = int(checkpoint_token)
        except ValueError:
            raise ValueError(
                f"{checkpoint_var} must be an integer, got {checkpoint_token!r}"
            ) from None

    return TrainingCurveRecorder(
        output_dir=Path(output_dir) / mode,
        prefix=mode,
        average_window=average_window,
        checkpoint_every=checkpoint_every,
    )

Build a curve recorder from env vars with sensible defaults.

Curve capture is enabled by default so short training runs still write the SVG/JSON artifacts used in the report. Set OAK_EXAMPLE_DISABLE_CURVES=1 to turn it off.

@dataclass(slots=True)
class TrainingCurveRecorder:
558@dataclass(slots=True)
559class TrainingCurveRecorder:
560    """Persist reward and moving-average curves for a full training run."""
561
562    output_dir: Path
563    prefix: str = "training"
564    average_window: int = 100
565    reward_color: str = "#264653"
566    average_color: str = "#2a9d8f"
567    epsilon_color: str = "#e76f51"
568    option_color: str = "#457b9d"
569    reward_history: list[float] = field(default_factory=list)
570    average_history: list[float] = field(default_factory=list)
571    epsilon_history: list[float] = field(default_factory=list)
572    option_history: list[float] = field(default_factory=list)
573    metric_histories: dict[str, list[float]] = field(default_factory=dict)
574    checkpoint_every: int = 0
575    metric_colors: dict[str, str] = field(
576        default_factory=lambda: {
577            "perception_encoder_loss": "#8a5cf6",
578            "value_q_omega_loss": "#c0392b",
579            "value_gvf_loss": "#d35400",
580            "policy_q_loss": "#2980b9",
581            "policy_termination_loss": "#16a085",
582            "model_loss": "#7f8c8d",
583            "model_done_loss": "#34495e",
584        }
585    )
586
587    def log_episode(
588        self,
589        episode: int,
590        reward: float,
591        avg_reward: float,
592        agent: Any,
593    ) -> None:
594        """Record one episode's aggregate metrics."""
595        current_episode = int(episode)
596        self.reward_history.append(float(reward))
597        self.average_history.append(float(avg_reward))
598        self.epsilon_history.append(
599            float(getattr(agent.reactive_policy, "epsilon", 0.0))
600        )
601        self.option_history.append(
602            float(len(getattr(agent.reactive_policy, "_options", {})))
603        )
604        current_metrics = _episode_metric_snapshot(agent)
605        episode_count = len(self.reward_history)
606        for metric_name in current_metrics:
607            self.metric_histories.setdefault(
608                metric_name,
609                [float("nan")] * (episode_count - 1),
610            )
611        for metric_name, history in self.metric_histories.items():
612            history.append(float(current_metrics.get(metric_name, float("nan"))))
613        if self.checkpoint_every > 0 and (current_episode + 1) % self.checkpoint_every == 0:
614            self.save_checkpoint()
615
616    def _history_payload(self) -> dict[str, Any]:
617        return {
618            "average_window": self.average_window,
619            "episodes": len(self.reward_history),
620            "reward_history": self.reward_history,
621            "average_history": self.average_history,
622            "epsilon_history": self.epsilon_history,
623            "option_history": self.option_history,
624            "metric_histories": self.metric_histories,
625            "best_reward": max(self.reward_history),
626            "final_reward": self.reward_history[-1],
627            "final_average": self.average_history[-1],
628        }
629
630    def save_checkpoint(self) -> None:
631        """Write the raw training histories without rendering plots."""
632        if not self.reward_history:
633            return
634        self.output_dir.mkdir(parents=True, exist_ok=True)
635        stem = self.output_dir / self.prefix
636        payload = self._history_payload()
637        payload["checkpoint"] = True
638        stem.with_name(f"{self.prefix}_reward_history_checkpoint.json").write_text(
639            json.dumps(payload, indent=2) + "\n"
640        )
641
642    def save(self) -> None:
643        """Write SVG plots and the raw history JSON."""
644        if not self.reward_history:
645            return
646
647        self.output_dir.mkdir(parents=True, exist_ok=True)
648        stem = self.output_dir / self.prefix
649
650        _write_line_plot_assets(
651            stem.with_name(f"{self.prefix}_reward_curve"),
652            title=f"{self.prefix} reward over episodes",
653            x_label="Episode",
654            y_label="Reward",
655            series=[
656                (
657                    "reward",
658                    self.reward_history,
659                    self.reward_color,
660                    _RAW_SERIES_OPACITY,
661                    _RAW_SERIES_WIDTH,
662                ),
663                (
664                    f"avg{self.average_window}",
665                    _moving_average(self.reward_history, self.average_window),
666                    self.average_color,
667                    _DEFAULT_SERIES_OPACITY,
668                    _AVG_SERIES_WIDTH,
669                ),
670            ],
671        )
672        _write_line_plot_assets(
673            stem.with_name(f"{self.prefix}_training_state"),
674            title=f"{self.prefix} exploration and option count",
675            x_label="Episode",
676            y_label="Value",
677            series=[
678                ("epsilon", self.epsilon_history, self.epsilon_color),
679                ("options", self.option_history, self.option_color),
680            ],
681        )
682
683        for metric_name, history in sorted(self.metric_histories.items()):
684            if not history:
685                continue
686            metric_label = metric_name.replace("_", " ")
687            color = self.metric_colors.get(metric_name, "#6b7280")
688            smoothed = _moving_average(history, self.average_window)
689            _write_line_plot_assets(
690                stem.with_name(f"{self.prefix}_{metric_name}"),
691                title=f"{self.prefix} {metric_label} over episodes",
692                x_label="Episode",
693                y_label=metric_label,
694                series=[
695                    (
696                        f"{metric_label} raw",
697                        history,
698                        color,
699                        _RAW_SERIES_OPACITY,
700                        _RAW_SERIES_WIDTH,
701                    ),
702                    (
703                        f"{metric_label} avg{self.average_window}",
704                        smoothed,
705                        color,
706                        _DEFAULT_SERIES_OPACITY,
707                        _AVG_SERIES_WIDTH,
708                    ),
709                ],
710            )
711
712        payload = self._history_payload()
713        stem.with_name(f"{self.prefix}_reward_history.json").write_text(
714            json.dumps(payload, indent=2) + "\n"
715        )

Persist reward and moving-average curves for a full training run.

TrainingCurveRecorder( output_dir: 'Path', prefix: 'str' = 'training', average_window: 'int' = 100, reward_color: 'str' = '#264653', average_color: 'str' = '#2a9d8f', epsilon_color: 'str' = '#e76f51', option_color: 'str' = '#457b9d', reward_history: 'list[float]' = <factory>, average_history: 'list[float]' = <factory>, epsilon_history: 'list[float]' = <factory>, option_history: 'list[float]' = <factory>, metric_histories: 'dict[str, list[float]]' = <factory>, checkpoint_every: 'int' = 0, metric_colors: 'dict[str, str]' = <factory>)
output_dir: 'Path'
prefix: 'str'
average_window: 'int'
reward_color: 'str'
average_color: 'str'
epsilon_color: 'str'
option_color: 'str'
reward_history: 'list[float]'
average_history: 'list[float]'
epsilon_history: 'list[float]'
option_history: 'list[float]'
metric_histories: 'dict[str, list[float]]'
checkpoint_every: 'int'
metric_colors: 'dict[str, str]'
def log_episode( self, episode: 'int', reward: 'float', avg_reward: 'float', agent: 'Any') -> 'None':
587    def log_episode(
588        self,
589        episode: int,
590        reward: float,
591        avg_reward: float,
592        agent: Any,
593    ) -> None:
594        """Record one episode's aggregate metrics."""
595        current_episode = int(episode)
596        self.reward_history.append(float(reward))
597        self.average_history.append(float(avg_reward))
598        self.epsilon_history.append(
599            float(getattr(agent.reactive_policy, "epsilon", 0.0))
600        )
601        self.option_history.append(
602            float(len(getattr(agent.reactive_policy, "_options", {})))
603        )
604        current_metrics = _episode_metric_snapshot(agent)
605        episode_count = len(self.reward_history)
606        for metric_name in current_metrics:
607            self.metric_histories.setdefault(
608                metric_name,
609                [float("nan")] * (episode_count - 1),
610            )
611        for metric_name, history in self.metric_histories.items():
612            history.append(float(current_metrics.get(metric_name, float("nan"))))
613        if self.checkpoint_every > 0 and (current_episode + 1) % self.checkpoint_every == 0:
614            self.save_checkpoint()

Record one episode's aggregate metrics.

def save_checkpoint(self) -> 'None':
630    def save_checkpoint(self) -> None:
631        """Write the raw training histories without rendering plots."""
632        if not self.reward_history:
633            return
634        self.output_dir.mkdir(parents=True, exist_ok=True)
635        stem = self.output_dir / self.prefix
636        payload = self._history_payload()
637        payload["checkpoint"] = True
638        stem.with_name(f"{self.prefix}_reward_history_checkpoint.json").write_text(
639            json.dumps(payload, indent=2) + "\n"
640        )

Write the raw training histories without rendering plots.

def save(self) -> 'None':
642    def save(self) -> None:
643        """Write SVG plots and the raw history JSON."""
644        if not self.reward_history:
645            return
646
647        self.output_dir.mkdir(parents=True, exist_ok=True)
648        stem = self.output_dir / self.prefix
649
650        _write_line_plot_assets(
651            stem.with_name(f"{self.prefix}_reward_curve"),
652            title=f"{self.prefix} reward over episodes",
653            x_label="Episode",
654            y_label="Reward",
655            series=[
656                (
657                    "reward",
658                    self.reward_history,
659                    self.reward_color,
660                    _RAW_SERIES_OPACITY,
661                    _RAW_SERIES_WIDTH,
662                ),
663                (
664                    f"avg{self.average_window}",
665                    _moving_average(self.reward_history, self.average_window),
666                    self.average_color,
667                    _DEFAULT_SERIES_OPACITY,
668                    _AVG_SERIES_WIDTH,
669                ),
670            ],
671        )
672        _write_line_plot_assets(
673            stem.with_name(f"{self.prefix}_training_state"),
674            title=f"{self.prefix} exploration and option count",
675            x_label="Episode",
676            y_label="Value",
677            series=[
678                ("epsilon", self.epsilon_history, self.epsilon_color),
679                ("options", self.option_history, self.option_color),
680            ],
681        )
682
683        for metric_name, history in sorted(self.metric_histories.items()):
684            if not history:
685                continue
686            metric_label = metric_name.replace("_", " ")
687            color = self.metric_colors.get(metric_name, "#6b7280")
688            smoothed = _moving_average(history, self.average_window)
689            _write_line_plot_assets(
690                stem.with_name(f"{self.prefix}_{metric_name}"),
691                title=f"{self.prefix} {metric_label} over episodes",
692                x_label="Episode",
693                y_label=metric_label,
694                series=[
695                    (
696                        f"{metric_label} raw",
697                        history,
698                        color,
699                        _RAW_SERIES_OPACITY,
700                        _RAW_SERIES_WIDTH,
701                    ),
702                    (
703                        f"{metric_label} avg{self.average_window}",
704                        smoothed,
705                        color,
706                        _DEFAULT_SERIES_OPACITY,
707                        _AVG_SERIES_WIDTH,
708                    ),
709                ],
710            )
711
712        payload = self._history_payload()
713        stem.with_name(f"{self.prefix}_reward_history.json").write_text(
714            json.dumps(payload, indent=2) + "\n"
715        )

Write SVG plots and the raw history JSON.