examples.example_01.training_logging
Training-time logging helpers for Example 01.
1"""Training-time logging helpers for Example 01.""" 2 3from __future__ import annotations 4 5import json 6import math 7import os 8from dataclasses import dataclass, field 9from html import escape 10from pathlib import Path 11from typing import Any 12 13import numpy as np 14from PIL import Image, ImageDraw, ImageFont 15 16from oak.types import EpisodeTrace 17 18_DEFAULT_ANIMATION_OUTPUT_DIR = Path("tests/results/animations") 19_DEFAULT_ANIMATION_EVERY = 100 20_DEFAULT_ANIMATION_LAST = 1 21_DEFAULT_ANIMATION_FPS = 30 22_DEFAULT_ANIMATION_MAX_FRAMES = 400 23_DEFAULT_CURVE_OUTPUT_DIR = Path("tests/results/training_curves") 24_DEFAULT_CURVE_CHECKPOINT_EVERY = 10 25_DEFAULT_SERIES_OPACITY = 1.0 26_RAW_SERIES_OPACITY = 0.25 27_RAW_SERIES_WIDTH = 1.5 28_AVG_SERIES_WIDTH = 2.75 29 30 31def _moving_average(values: list[float], window: int) -> list[float]: 32 if not values: 33 return [] 34 window = max(window, 1) 35 averaged: list[float] = [] 36 for index, value in enumerate(values): 37 del value 38 start = max(0, index - window + 1) 39 window_values = [ 40 float(candidate) 41 for candidate in values[start : index + 1] 42 if math.isfinite(float(candidate)) 43 ] 44 if not window_values: 45 averaged.append(float("nan")) 46 else: 47 averaged.append(sum(window_values) / len(window_values)) 48 return averaged 49 50 51def _episode_metric_snapshot(agent: Any) -> dict[str, float]: 52 metrics: dict[str, float] = {} 53 for component_name in ( 54 "perception", 55 "value_function", 56 "reactive_policy", 57 "transition_model", 58 ): 59 component = getattr(agent, component_name, None) 60 metric_fn = getattr(component, "training_metrics", None) 61 if callable(metric_fn): 62 for key, value in dict(metric_fn()).items(): 63 if math.isfinite(float(value)): 64 metrics[key] = float(value) 65 return metrics 66 67 68@dataclass(slots=True, frozen=True) 69class EpisodeCaptureSchedule: 70 """Select which training episodes should produce artifacts.""" 71 72 episode_indices: tuple[int, ...] 
= () 73 every_n_episodes: int | None = None 74 last_n_episodes: int = 0 75 76 def should_capture(self, episode: int, total_episodes: int) -> bool: 77 if episode in self.episode_indices: 78 return True 79 if self.every_n_episodes is not None and self.every_n_episodes > 0: 80 if episode % self.every_n_episodes == 0: 81 return True 82 if self.last_n_episodes > 0 and episode >= max(total_episodes - self.last_n_episodes, 0): 83 return True 84 return False 85 86 87@dataclass(slots=True) 88class EpisodeAnimationRecorder: 89 """Save traced episode frames as GIF animations plus compact metadata.""" 90 91 output_dir: Path 92 schedule: EpisodeCaptureSchedule = field(default_factory=EpisodeCaptureSchedule) 93 prefix: str = "training" 94 fps: int = 30 95 max_frames: int | None = None 96 save_metadata: bool = True 97 98 def should_capture(self, episode: int, total_episodes: int) -> bool: 99 return self.schedule.should_capture(episode, total_episodes) 100 101 def __call__(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> None: 102 self.output_dir.mkdir(parents=True, exist_ok=True) 103 stem = self._episode_stem(trace) 104 metadata_path = self.output_dir / f"{stem}.json" 105 gif_path = self.output_dir / f"{stem}.gif" 106 107 frames = list(trace.frames) 108 if self.max_frames is not None and self.max_frames > 0: 109 frames = frames[: self.max_frames] 110 if not frames: 111 if self.save_metadata: 112 metadata = self._trace_metadata(trace) 113 metadata["warning"] = ( 114 "No frames were captured for this episode. Ensure the world was " 115 "created with render_mode='rgb_array' and the environment's " 116 "rendering dependencies are installed." 
117 ) 118 metadata_path.write_text(json.dumps(metadata, indent=2) + "\n") 119 return 120 121 pil_frames = [self._to_pil_image(frame) for frame in frames] 122 pil_frames[0].save( 123 gif_path, 124 save_all=True, 125 append_images=pil_frames[1:], 126 duration=max(int(1000 / max(self.fps, 1)), 1), 127 loop=0, 128 ) 129 130 if self.save_metadata: 131 metadata_path.write_text(json.dumps(self._trace_metadata(trace), indent=2) + "\n") 132 133 def _episode_stem(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> str: 134 return ( 135 f"{self.prefix}_episode_{trace.episode:04d}" 136 f"_reward_{int(round(trace.episode_reward))}" 137 f"_avg_{int(round(trace.avg_reward))}" 138 ) 139 140 def _trace_metadata( 141 self, 142 trace: EpisodeTrace[Any, Any, Any, Any], 143 ) -> dict[str, Any]: 144 return { 145 "episode": trace.episode, 146 "episode_reward": trace.episode_reward, 147 "avg_reward": trace.avg_reward, 148 "step_count": trace.step_count, 149 "solved": trace.solved, 150 "frame_count": len(trace.frames), 151 "actions": [step.action for step in trace.steps], 152 "active_option_ids": [step.active_option_id for step in trace.steps], 153 "metadata": dict(trace.metadata), 154 } 155 156 def _to_pil_image(self, frame: object) -> Image.Image: 157 if isinstance(frame, Image.Image): 158 return frame.convert("RGB") 159 160 array = np.asarray(frame) 161 if array.ndim == 2: 162 array = np.stack([array] * 3, axis=-1) 163 if array.dtype != np.uint8: 164 array = np.clip(array, 0, 255).astype(np.uint8) 165 return Image.fromarray(array).convert("RGB") 166 167 168def animation_recorder_from_env(mode: str) -> EpisodeAnimationRecorder | None: 169 """Build an animation recorder from env vars with sensible defaults. 170 171 Animation capture is enabled by default for the example training entry 172 points so that long runs automatically leave behind visual artifacts. 173 Set ``OAK_EXAMPLE_DISABLE_ANIMATIONS=1`` to turn it off. 
174 """ 175 176 disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_ANIMATIONS", "").strip().lower() 177 if disable_token in {"1", "true", "yes", "on"}: 178 return None 179 180 output_dir = ( 181 os.environ.get("OAK_EXAMPLE_ANIMATION_DIR", "").strip() 182 or str(_DEFAULT_ANIMATION_OUTPUT_DIR) 183 ) 184 episode_tokens = os.environ.get("OAK_EXAMPLE_ANIMATION_EPISODES", "").split(",") 185 episode_indices = tuple( 186 int(token.strip()) 187 for token in episode_tokens 188 if token.strip() 189 ) 190 every_n = int( 191 os.environ.get( 192 "OAK_EXAMPLE_ANIMATION_EVERY", 193 str(_DEFAULT_ANIMATION_EVERY), 194 ) 195 or _DEFAULT_ANIMATION_EVERY 196 ) 197 last_n = int( 198 os.environ.get( 199 "OAK_EXAMPLE_ANIMATION_LAST", 200 str(_DEFAULT_ANIMATION_LAST), 201 ) 202 or _DEFAULT_ANIMATION_LAST 203 ) 204 fps = int( 205 os.environ.get( 206 "OAK_EXAMPLE_ANIMATION_FPS", 207 str(_DEFAULT_ANIMATION_FPS), 208 ) 209 or _DEFAULT_ANIMATION_FPS 210 ) 211 max_frames = int( 212 os.environ.get( 213 "OAK_EXAMPLE_ANIMATION_MAX_FRAMES", 214 str(_DEFAULT_ANIMATION_MAX_FRAMES), 215 ) 216 or _DEFAULT_ANIMATION_MAX_FRAMES 217 ) 218 219 return EpisodeAnimationRecorder( 220 output_dir=Path(output_dir) / mode, 221 schedule=EpisodeCaptureSchedule( 222 episode_indices=episode_indices, 223 every_n_episodes=every_n or None, 224 last_n_episodes=last_n, 225 ), 226 prefix=mode, 227 fps=fps, 228 max_frames=max_frames or None, 229 ) 230 231 232def curve_recorder_from_env( 233 mode: str, 234 *, 235 average_window: int, 236) -> TrainingCurveRecorder | None: 237 """Build a curve recorder from env vars with sensible defaults. 238 239 Curve capture is enabled by default so short training runs still write the 240 SVG/JSON artifacts used in the report. Set 241 ``OAK_EXAMPLE_DISABLE_CURVES=1`` to turn it off. 
242 """ 243 244 disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_CURVES", "").strip().lower() 245 if disable_token in {"1", "true", "yes", "on"}: 246 return None 247 248 output_dir = ( 249 os.environ.get("OAK_EXAMPLE_CURVE_DIR", "").strip() 250 or os.environ.get("OAK_EXAMPLE_PLOT_DIR", "").strip() 251 or str(_DEFAULT_CURVE_OUTPUT_DIR) 252 ) 253 checkpoint_every = int( 254 os.environ.get( 255 "OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY", 256 str(_DEFAULT_CURVE_CHECKPOINT_EVERY), 257 ) 258 or _DEFAULT_CURVE_CHECKPOINT_EVERY 259 ) 260 return TrainingCurveRecorder( 261 output_dir=Path(output_dir) / mode, 262 prefix=mode, 263 average_window=average_window, 264 checkpoint_every=checkpoint_every, 265 ) 266 267 268def _write_line_plot_svg( 269 path: Path, 270 *, 271 title: str, 272 x_label: str, 273 y_label: str, 274 series: list[tuple[Any, ...]], 275) -> None: 276 """Write a lightweight SVG line chart without external plotting deps.""" 277 width = 960 278 height = 540 279 left = 70 280 right = 24 281 top = 52 282 bottom = 58 283 plot_width = width - left - right 284 plot_height = height - top - bottom 285 286 values = [ 287 float(value) 288 for entry in series 289 for value in entry[1] 290 if math.isfinite(float(value)) 291 ] 292 if not values: 293 values = [0.0, 1.0] 294 295 y_min = min(values) 296 y_max = max(values) 297 if math.isclose(y_min, y_max): 298 pad = max(1.0, abs(y_min) * 0.1 + 1.0) 299 y_min -= pad 300 y_max += pad 301 302 x_max = max((len(entry[1]) - 1 for entry in series if entry[1]), default=1) 303 x_max = max(x_max, 1) 304 305 def x_pos(index: int) -> float: 306 return left + (index / x_max) * plot_width 307 308 def y_pos(value: float) -> float: 309 norm = (value - y_min) / (y_max - y_min) 310 return top + (1.0 - norm) * plot_height 311 312 grid_lines: list[str] = [] 313 tick_labels: list[str] = [] 314 for tick in range(6): 315 frac = tick / 5 316 y_value = y_max - frac * (y_max - y_min) 317 y = top + frac * plot_height 318 grid_lines.append( 319 f'<line 
x1="{left}" y1="{y:.2f}" x2="{width - right}" y2="{y:.2f}" ' 320 'stroke="#d7dde5" stroke-width="1" />' 321 ) 322 tick_labels.append( 323 f'<text x="{left - 10}" y="{y + 4:.2f}" text-anchor="end" ' 324 'font-size="12" fill="#334155">' 325 f"{y_value:.2f}</text>" 326 ) 327 328 x_ticks = sorted({round(tick * x_max / 5) for tick in range(6)}) 329 x_tick_labels: list[str] = [] 330 for x_tick in x_ticks: 331 x = x_pos(int(x_tick)) 332 grid_lines.append( 333 f'<line x1="{x:.2f}" y1="{top}" x2="{x:.2f}" y2="{top + plot_height:.2f}" ' 334 'stroke="#eef2f7" stroke-width="1" />' 335 ) 336 x_tick_labels.append( 337 f'<text x="{x:.2f}" y="{top + plot_height + 20:.2f}" text-anchor="middle" ' 338 'font-size="12" fill="#334155">' 339 f"{int(x_tick)}</text>" 340 ) 341 342 polylines: list[str] = [] 343 legend_items: list[str] = [] 344 for index, entry in enumerate(series): 345 label = str(entry[0]) 346 seq = entry[1] 347 color = str(entry[2]) 348 opacity = float(entry[3]) if len(entry) >= 4 else _DEFAULT_SERIES_OPACITY 349 stroke_width = float(entry[4]) if len(entry) >= 5 else 2.5 350 points = " ".join( 351 f"{x_pos(i):.2f},{y_pos(float(value)):.2f}" 352 for i, value in enumerate(seq) 353 if math.isfinite(float(value)) 354 ) 355 if points: 356 polylines.append( 357 f'<polyline fill="none" stroke="{color}" stroke-width="{stroke_width}" ' 358 f'stroke-opacity="{opacity:.3f}" ' 359 f'stroke-linejoin="round" stroke-linecap="round" points="{points}" />' 360 ) 361 legend_y = top + 16 + 20 * index 362 legend_items.append( 363 f'<rect x="{width - right - 180}" y="{legend_y - 10}" width="14" height="14" ' 364 f'fill="{color}" fill-opacity="{max(opacity, 0.35):.3f}" rx="3" />' 365 f'<text x="{width - right - 160}" y="{legend_y + 1}" font-size="13" ' 366 'fill="#0f172a">' 367 f"{escape(label)}</text>" 368 ) 369 370 svg = f"""<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}" role="img" aria-label="{escape(title)}"> 371 <rect width="100%" 
height="100%" fill="#f8fafc" /> 372 <text x="{left}" y="28" font-size="22" font-weight="700" fill="#0f172a">{escape(title)}</text> 373 <rect x="{left}" y="{top}" width="{plot_width}" height="{plot_height}" fill="#ffffff" stroke="#cbd5e1" stroke-width="1.5" rx="10" /> 374 {''.join(grid_lines)} 375 {''.join(tick_labels)} 376 {''.join(x_tick_labels)} 377 {''.join(polylines)} 378 {''.join(legend_items)} 379 <text x="{left + plot_width / 2:.2f}" y="{height - 16}" text-anchor="middle" font-size="14" fill="#334155">{escape(x_label)}</text> 380 <text x="20" y="{top + plot_height / 2:.2f}" text-anchor="middle" font-size="14" fill="#334155" transform="rotate(-90 20 {top + plot_height / 2:.2f})">{escape(y_label)}</text> 381</svg> 382""" 383 path.parent.mkdir(parents=True, exist_ok=True) 384 path.write_text(svg) 385 386 387def _hex_to_rgba(color: str, alpha: float = 1.0) -> tuple[int, int, int, int]: 388 color = color.lstrip("#") 389 if len(color) != 6: 390 return (107, 114, 128, int(255 * alpha)) 391 red = int(color[0:2], 16) 392 green = int(color[2:4], 16) 393 blue = int(color[4:6], 16) 394 return (red, green, blue, max(0, min(255, int(round(255 * alpha))))) 395 396 397def _write_line_plot_png( 398 path: Path, 399 *, 400 title: str, 401 x_label: str, 402 y_label: str, 403 series: list[tuple[Any, ...]], 404) -> None: 405 """Write a lightweight PNG line chart without plotting dependencies.""" 406 width = 960 407 height = 540 408 left = 70 409 right = 24 410 top = 52 411 bottom = 58 412 plot_width = width - left - right 413 plot_height = height - top - bottom 414 415 values = [ 416 float(value) 417 for entry in series 418 for value in entry[1] 419 if math.isfinite(float(value)) 420 ] 421 if not values: 422 values = [0.0, 1.0] 423 424 y_min = min(values) 425 y_max = max(values) 426 if math.isclose(y_min, y_max): 427 pad = max(1.0, abs(y_min) * 0.1 + 1.0) 428 y_min -= pad 429 y_max += pad 430 431 x_max = max((len(entry[1]) - 1 for entry in series if entry[1]), default=1) 432 
x_max = max(x_max, 1) 433 434 def x_pos(index: int) -> float: 435 return left + (index / x_max) * plot_width 436 437 def y_pos(value: float) -> float: 438 norm = (value - y_min) / (y_max - y_min) 439 return top + (1.0 - norm) * plot_height 440 441 image = Image.new("RGBA", (width, height), (248, 250, 252, 255)) 442 draw = ImageDraw.Draw(image, "RGBA") 443 font = ImageFont.load_default() 444 445 draw.rounded_rectangle( 446 (left, top, left + plot_width, top + plot_height), 447 radius=10, 448 fill=(255, 255, 255, 255), 449 outline=(203, 213, 225, 255), 450 width=2, 451 ) 452 draw.text((left, 16), title, fill=(15, 23, 42, 255), font=font) 453 454 for tick in range(6): 455 frac = tick / 5 456 y_value = y_max - frac * (y_max - y_min) 457 y = top + frac * plot_height 458 draw.line( 459 ((left, y), (width - right, y)), 460 fill=(215, 221, 229, 255), 461 width=1, 462 ) 463 draw.text( 464 (left - 52, y - 6), 465 f"{y_value:.2f}", 466 fill=(51, 65, 85, 255), 467 font=font, 468 ) 469 470 x_ticks = sorted({round(tick * x_max / 5) for tick in range(6)}) 471 for x_tick in x_ticks: 472 x = x_pos(int(x_tick)) 473 draw.line( 474 ((x, top), (x, top + plot_height)), 475 fill=(238, 242, 247, 255), 476 width=1, 477 ) 478 draw.text( 479 (x - 8, top + plot_height + 8), 480 str(int(x_tick)), 481 fill=(51, 65, 85, 255), 482 font=font, 483 ) 484 485 for entry in series: 486 seq = entry[1] 487 color = str(entry[2]) 488 opacity = float(entry[3]) if len(entry) >= 4 else _DEFAULT_SERIES_OPACITY 489 stroke_width = max(1, int(round(float(entry[4]) if len(entry) >= 5 else 2.5))) 490 points = [ 491 (x_pos(i), y_pos(float(value))) 492 for i, value in enumerate(seq) 493 if math.isfinite(float(value)) 494 ] 495 if len(points) >= 2: 496 draw.line( 497 points, 498 fill=_hex_to_rgba(color, opacity), 499 width=stroke_width, 500 joint="curve", 501 ) 502 503 legend_x = width - right - 180 504 for index, entry in enumerate(series): 505 label = str(entry[0]) 506 color = str(entry[2]) 507 opacity = 
float(entry[3]) if len(entry) >= 4 else _DEFAULT_SERIES_OPACITY 508 legend_y = top + 6 + 20 * index 509 draw.rounded_rectangle( 510 (legend_x, legend_y, legend_x + 14, legend_y + 14), 511 radius=3, 512 fill=_hex_to_rgba(color, max(opacity, 0.35)), 513 ) 514 draw.text( 515 (legend_x + 20, legend_y), 516 label, 517 fill=(15, 23, 42, 255), 518 font=font, 519 ) 520 521 draw.text( 522 (left + plot_width / 2 - 25, height - 24), 523 x_label, 524 fill=(51, 65, 85, 255), 525 font=font, 526 ) 527 draw.text((8, top + plot_height / 2), y_label, fill=(51, 65, 85, 255), font=font) 528 529 path.parent.mkdir(parents=True, exist_ok=True) 530 image.convert("RGB").save(path, format="PNG") 531 532 533def _write_line_plot_assets( 534 stem: Path, 535 *, 536 title: str, 537 x_label: str, 538 y_label: str, 539 series: list[tuple[Any, ...]], 540) -> None: 541 _write_line_plot_svg( 542 stem.with_suffix(".svg"), 543 title=title, 544 x_label=x_label, 545 y_label=y_label, 546 series=series, 547 ) 548 _write_line_plot_png( 549 stem.with_suffix(".png"), 550 title=title, 551 x_label=x_label, 552 y_label=y_label, 553 series=series, 554 ) 555 556 557@dataclass(slots=True) 558class TrainingCurveRecorder: 559 """Persist reward and moving-average curves for a full training run.""" 560 561 output_dir: Path 562 prefix: str = "training" 563 average_window: int = 100 564 reward_color: str = "#264653" 565 average_color: str = "#2a9d8f" 566 epsilon_color: str = "#e76f51" 567 option_color: str = "#457b9d" 568 reward_history: list[float] = field(default_factory=list) 569 average_history: list[float] = field(default_factory=list) 570 epsilon_history: list[float] = field(default_factory=list) 571 option_history: list[float] = field(default_factory=list) 572 metric_histories: dict[str, list[float]] = field(default_factory=dict) 573 checkpoint_every: int = 0 574 metric_colors: dict[str, str] = field( 575 default_factory=lambda: { 576 "perception_encoder_loss": "#8a5cf6", 577 "value_q_omega_loss": "#c0392b", 578 
"value_gvf_loss": "#d35400", 579 "policy_q_loss": "#2980b9", 580 "policy_termination_loss": "#16a085", 581 "model_loss": "#7f8c8d", 582 "model_done_loss": "#34495e", 583 } 584 ) 585 586 def log_episode( 587 self, 588 episode: int, 589 reward: float, 590 avg_reward: float, 591 agent: Any, 592 ) -> None: 593 """Record one episode's aggregate metrics.""" 594 current_episode = int(episode) 595 self.reward_history.append(float(reward)) 596 self.average_history.append(float(avg_reward)) 597 self.epsilon_history.append( 598 float(getattr(agent.reactive_policy, "epsilon", 0.0)) 599 ) 600 self.option_history.append( 601 float(len(getattr(agent.reactive_policy, "_options", {}))) 602 ) 603 current_metrics = _episode_metric_snapshot(agent) 604 episode_count = len(self.reward_history) 605 for metric_name in current_metrics: 606 self.metric_histories.setdefault( 607 metric_name, 608 [float("nan")] * (episode_count - 1), 609 ) 610 for metric_name, history in self.metric_histories.items(): 611 history.append(float(current_metrics.get(metric_name, float("nan")))) 612 if self.checkpoint_every > 0 and (current_episode + 1) % self.checkpoint_every == 0: 613 self.save_checkpoint() 614 615 def _history_payload(self) -> dict[str, Any]: 616 return { 617 "average_window": self.average_window, 618 "episodes": len(self.reward_history), 619 "reward_history": self.reward_history, 620 "average_history": self.average_history, 621 "epsilon_history": self.epsilon_history, 622 "option_history": self.option_history, 623 "metric_histories": self.metric_histories, 624 "best_reward": max(self.reward_history), 625 "final_reward": self.reward_history[-1], 626 "final_average": self.average_history[-1], 627 } 628 629 def save_checkpoint(self) -> None: 630 """Write the raw training histories without rendering plots.""" 631 if not self.reward_history: 632 return 633 self.output_dir.mkdir(parents=True, exist_ok=True) 634 stem = self.output_dir / self.prefix 635 payload = self._history_payload() 636 
payload["checkpoint"] = True 637 stem.with_name(f"{self.prefix}_reward_history_checkpoint.json").write_text( 638 json.dumps(payload, indent=2) + "\n" 639 ) 640 641 def save(self) -> None: 642 """Write SVG plots and the raw history JSON.""" 643 if not self.reward_history: 644 return 645 646 self.output_dir.mkdir(parents=True, exist_ok=True) 647 stem = self.output_dir / self.prefix 648 649 _write_line_plot_assets( 650 stem.with_name(f"{self.prefix}_reward_curve"), 651 title=f"{self.prefix} reward over episodes", 652 x_label="Episode", 653 y_label="Reward", 654 series=[ 655 ( 656 "reward", 657 self.reward_history, 658 self.reward_color, 659 _RAW_SERIES_OPACITY, 660 _RAW_SERIES_WIDTH, 661 ), 662 ( 663 f"avg{self.average_window}", 664 _moving_average(self.reward_history, self.average_window), 665 self.average_color, 666 _DEFAULT_SERIES_OPACITY, 667 _AVG_SERIES_WIDTH, 668 ), 669 ], 670 ) 671 _write_line_plot_assets( 672 stem.with_name(f"{self.prefix}_training_state"), 673 title=f"{self.prefix} exploration and option count", 674 x_label="Episode", 675 y_label="Value", 676 series=[ 677 ("epsilon", self.epsilon_history, self.epsilon_color), 678 ("options", self.option_history, self.option_color), 679 ], 680 ) 681 682 for metric_name, history in sorted(self.metric_histories.items()): 683 if not history: 684 continue 685 metric_label = metric_name.replace("_", " ") 686 color = self.metric_colors.get(metric_name, "#6b7280") 687 smoothed = _moving_average(history, self.average_window) 688 _write_line_plot_assets( 689 stem.with_name(f"{self.prefix}_{metric_name}"), 690 title=f"{self.prefix} {metric_label} over episodes", 691 x_label="Episode", 692 y_label=metric_label, 693 series=[ 694 ( 695 f"{metric_label} raw", 696 history, 697 color, 698 _RAW_SERIES_OPACITY, 699 _RAW_SERIES_WIDTH, 700 ), 701 ( 702 f"{metric_label} avg{self.average_window}", 703 smoothed, 704 color, 705 _DEFAULT_SERIES_OPACITY, 706 _AVG_SERIES_WIDTH, 707 ), 708 ], 709 ) 710 711 payload = 
self._history_payload() 712 stem.with_name(f"{self.prefix}_reward_history.json").write_text( 713 json.dumps(payload, indent=2) + "\n" 714 )
69@dataclass(slots=True, frozen=True) 70class EpisodeCaptureSchedule: 71 """Select which training episodes should produce artifacts.""" 72 73 episode_indices: tuple[int, ...] = () 74 every_n_episodes: int | None = None 75 last_n_episodes: int = 0 76 77 def should_capture(self, episode: int, total_episodes: int) -> bool: 78 if episode in self.episode_indices: 79 return True 80 if self.every_n_episodes is not None and self.every_n_episodes > 0: 81 if episode % self.every_n_episodes == 0: 82 return True 83 if self.last_n_episodes > 0 and episode >= max(total_episodes - self.last_n_episodes, 0): 84 return True 85 return False
Select which training episodes should produce artifacts.
def should_capture(self, episode: int, total_episodes: int) -> bool:
    """Report whether *episode* matches any of the schedule's capture rules.

    Rules, checked in order: explicit membership in ``episode_indices``, a
    positive ``every_n_episodes`` period, and membership in the final
    ``last_n_episodes`` episodes of a run of ``total_episodes``.
    """
    if episode in self.episode_indices:
        return True
    period = self.every_n_episodes
    if period is not None and period > 0 and episode % period == 0:
        return True
    if self.last_n_episodes <= 0:
        return False
    return episode >= max(total_episodes - self.last_n_episodes, 0)
@dataclass(slots=True)
class EpisodeAnimationRecorder:
    """Save traced episode frames as GIF animations plus compact metadata."""

    output_dir: Path
    schedule: EpisodeCaptureSchedule = field(default_factory=EpisodeCaptureSchedule)
    prefix: str = "training"
    fps: int = 30
    # Positive values cap the frames written to the GIF; None/<=0 means no cap.
    max_frames: int | None = None
    save_metadata: bool = True

    def should_capture(self, episode: int, total_episodes: int) -> bool:
        """Delegate to the configured capture schedule."""
        return self.schedule.should_capture(episode, total_episodes)

    def __call__(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> None:
        """Write ``<stem>.gif`` and, if enabled, ``<stem>.json`` for *trace*.

        If the trace has no frames, no GIF is written. When ``save_metadata``
        is set, the JSON is still written with a ``warning`` entry.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)
        stem = self._episode_stem(trace)
        metadata_path = self.output_dir / f"{stem}.json"
        gif_path = self.output_dir / f"{stem}.gif"

        frames = list(trace.frames)
        if self.max_frames is not None and self.max_frames > 0:
            frames = frames[: self.max_frames]
        if not frames:
            if self.save_metadata:
                metadata = self._trace_metadata(trace)
                metadata["warning"] = (
                    "No frames were captured for this episode. Ensure the world was "
                    "created with render_mode='rgb_array' and the environment's "
                    "rendering dependencies are installed."
                )
                metadata_path.write_text(json.dumps(metadata, indent=2) + "\n")
            return

        pil_frames = [self._to_pil_image(frame) for frame in frames]
        pil_frames[0].save(
            gif_path,
            save_all=True,
            append_images=pil_frames[1:],
            # Per-frame duration in milliseconds; guard against fps <= 0.
            duration=max(int(1000 / max(self.fps, 1)), 1),
            loop=0,
        )

        if self.save_metadata:
            metadata_path.write_text(json.dumps(self._trace_metadata(trace), indent=2) + "\n")

    @staticmethod
    def _reward_token(value: float) -> str:
        """Format a reward for a filename: rounded int, or nan/inf/-inf.

        ``int(round(...))`` raises on NaN and infinity, so those are spelled
        out instead of crashing the recorder.
        """
        number = float(value)
        if math.isfinite(number):
            return str(int(round(number)))
        if math.isnan(number):
            return "nan"
        return "inf" if number > 0 else "-inf"

    def _episode_stem(self, trace: EpisodeTrace[Any, Any, Any, Any]) -> str:
        """Build the artifact file stem from prefix, episode and rewards."""
        return (
            f"{self.prefix}_episode_{trace.episode:04d}"
            f"_reward_{self._reward_token(trace.episode_reward)}"
            f"_avg_{self._reward_token(trace.avg_reward)}"
        )

    def _trace_metadata(
        self,
        trace: EpisodeTrace[Any, Any, Any, Any],
    ) -> dict[str, Any]:
        """Summarize *trace* as a JSON-serializable dict."""
        return {
            "episode": trace.episode,
            "episode_reward": trace.episode_reward,
            "avg_reward": trace.avg_reward,
            "step_count": trace.step_count,
            "solved": trace.solved,
            # Counts all traced frames, not only those kept by max_frames.
            "frame_count": len(trace.frames),
            "actions": [step.action for step in trace.steps],
            "active_option_ids": [step.active_option_id for step in trace.steps],
            "metadata": dict(trace.metadata),
        }

    def _to_pil_image(self, frame: object) -> Image.Image:
        """Convert a PIL image or array-like frame into an RGB PIL image.

        Grayscale ``(H, W)`` and single-channel ``(H, W, 1)`` arrays are
        expanded to three channels. Non-uint8 data is clipped to [0, 255] and
        cast to uint8.
        """
        if isinstance(frame, Image.Image):
            return frame.convert("RGB")

        array = np.asarray(frame)
        if array.ndim == 3 and array.shape[-1] == 1:
            # Drop the trailing channel axis so the grayscale path applies;
            # Image.fromarray does not map (H, W, 1) uint8 arrays to a mode.
            array = array[..., 0]
        if array.ndim == 2:
            array = np.stack([array] * 3, axis=-1)
        if array.dtype != np.uint8:
            array = np.clip(array, 0, 255).astype(np.uint8)
        return Image.fromarray(array).convert("RGB")
Save traced episode frames as GIF animations plus compact metadata.
def animation_recorder_from_env(mode: str) -> EpisodeAnimationRecorder | None:
    """Build an animation recorder from env vars with sensible defaults.

    Animation capture is enabled by default for the example training entry
    points so that long runs automatically leave behind visual artifacts.
    Set ``OAK_EXAMPLE_DISABLE_ANIMATIONS=1`` (or ``true``/``yes``/``on``) to
    turn it off.

    Recognized variables:

    * ``OAK_EXAMPLE_ANIMATION_DIR``: base output directory (``<dir>/<mode>``).
    * ``OAK_EXAMPLE_ANIMATION_EPISODES``: comma-separated episode indices.
    * ``OAK_EXAMPLE_ANIMATION_EVERY``: capture period; 0 disables it.
    * ``OAK_EXAMPLE_ANIMATION_LAST``: capture the final N episodes.
    * ``OAK_EXAMPLE_ANIMATION_FPS``: GIF frame rate.
    * ``OAK_EXAMPLE_ANIMATION_MAX_FRAMES``: frame cap; 0 means no cap.

    Blank or whitespace-only values fall back to the defaults.

    Raises:
        ValueError: if a numeric variable holds a non-integer value.
    """

    def env_int(name: str, default: int) -> int:
        # Blank values count as unset; int("  ") would otherwise raise.
        raw = os.environ.get(name, "").strip()
        return int(raw) if raw else default

    disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_ANIMATIONS", "").strip().lower()
    if disable_token in {"1", "true", "yes", "on"}:
        return None

    output_dir = (
        os.environ.get("OAK_EXAMPLE_ANIMATION_DIR", "").strip()
        or str(_DEFAULT_ANIMATION_OUTPUT_DIR)
    )
    episode_tokens = os.environ.get("OAK_EXAMPLE_ANIMATION_EPISODES", "").split(",")
    episode_indices = tuple(
        int(token.strip())
        for token in episode_tokens
        if token.strip()
    )
    every_n = env_int("OAK_EXAMPLE_ANIMATION_EVERY", _DEFAULT_ANIMATION_EVERY)
    last_n = env_int("OAK_EXAMPLE_ANIMATION_LAST", _DEFAULT_ANIMATION_LAST)
    fps = env_int("OAK_EXAMPLE_ANIMATION_FPS", _DEFAULT_ANIMATION_FPS)
    max_frames = env_int("OAK_EXAMPLE_ANIMATION_MAX_FRAMES", _DEFAULT_ANIMATION_MAX_FRAMES)

    return EpisodeAnimationRecorder(
        output_dir=Path(output_dir) / mode,
        schedule=EpisodeCaptureSchedule(
            episode_indices=episode_indices,
            every_n_episodes=every_n or None,
            last_n_episodes=last_n,
        ),
        prefix=mode,
        fps=fps,
        max_frames=max_frames or None,
    )
Build an animation recorder from env vars with sensible defaults.
Animation capture is enabled by default for the example training entry
points so that long runs automatically leave behind visual artifacts.
Set OAK_EXAMPLE_DISABLE_ANIMATIONS=1 (or true/yes/on) to turn it off.
def curve_recorder_from_env(
    mode: str,
    *,
    average_window: int,
) -> TrainingCurveRecorder | None:
    """Build a curve recorder from env vars with sensible defaults.

    Curve capture is enabled by default so short training runs still write the
    SVG/PNG/JSON artifacts used in the report. Set
    ``OAK_EXAMPLE_DISABLE_CURVES=1`` (or ``true``/``yes``/``on``) to turn it off.

    The output directory comes from ``OAK_EXAMPLE_CURVE_DIR``, then
    ``OAK_EXAMPLE_PLOT_DIR``, then the built-in default, with ``<mode>``
    appended. ``OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY`` sets the checkpoint
    period; blank values fall back to the default.

    Raises:
        ValueError: if the checkpoint variable holds a non-integer value.
    """

    disable_token = os.environ.get("OAK_EXAMPLE_DISABLE_CURVES", "").strip().lower()
    if disable_token in {"1", "true", "yes", "on"}:
        return None

    output_dir = (
        os.environ.get("OAK_EXAMPLE_CURVE_DIR", "").strip()
        or os.environ.get("OAK_EXAMPLE_PLOT_DIR", "").strip()
        or str(_DEFAULT_CURVE_OUTPUT_DIR)
    )
    # Blank values count as unset; int("  ") would otherwise raise.
    raw_checkpoint = os.environ.get("OAK_EXAMPLE_CURVE_CHECKPOINT_EVERY", "").strip()
    checkpoint_every = int(raw_checkpoint) if raw_checkpoint else _DEFAULT_CURVE_CHECKPOINT_EVERY
    return TrainingCurveRecorder(
        output_dir=Path(output_dir) / mode,
        prefix=mode,
        average_window=average_window,
        checkpoint_every=checkpoint_every,
    )
Build a curve recorder from env vars with sensible defaults.
Curve capture is enabled by default so short training runs still write the
SVG/PNG/JSON artifacts used in the report. Set
OAK_EXAMPLE_DISABLE_CURVES=1 to turn it off.
@dataclass(slots=True)
class TrainingCurveRecorder:
    """Persist reward, exploration and per-component metric curves for a run."""

    output_dir: Path
    prefix: str = "training"
    average_window: int = 100
    reward_color: str = "#264653"
    average_color: str = "#2a9d8f"
    epsilon_color: str = "#e76f51"
    option_color: str = "#457b9d"
    reward_history: list[float] = field(default_factory=list)
    average_history: list[float] = field(default_factory=list)
    epsilon_history: list[float] = field(default_factory=list)
    option_history: list[float] = field(default_factory=list)
    # One list per metric name, aligned with reward_history (NaN-padded).
    metric_histories: dict[str, list[float]] = field(default_factory=dict)
    # Write a checkpoint JSON every N episodes; 0 disables checkpoints.
    checkpoint_every: int = 0
    metric_colors: dict[str, str] = field(
        default_factory=lambda: {
            "perception_encoder_loss": "#8a5cf6",
            "value_q_omega_loss": "#c0392b",
            "value_gvf_loss": "#d35400",
            "policy_q_loss": "#2980b9",
            "policy_termination_loss": "#16a085",
            "model_loss": "#7f8c8d",
            "model_done_loss": "#34495e",
        }
    )

    def log_episode(
        self,
        episode: int,
        reward: float,
        avg_reward: float,
        agent: Any,
    ) -> None:
        """Record one episode's aggregate metrics.

        Epsilon and option count are read from ``agent.reactive_policy`` when
        present, defaulting to 0. A metric first seen now is back-filled with
        NaN so every history stays aligned with ``reward_history``. A metric
        missing this episode gets NaN.
        """
        current_episode = int(episode)
        self.reward_history.append(float(reward))
        self.average_history.append(float(avg_reward))
        # Tolerate agents without a reactive policy, matching the getattr
        # fallbacks used for the other components.
        policy = getattr(agent, "reactive_policy", None)
        self.epsilon_history.append(float(getattr(policy, "epsilon", 0.0)))
        self.option_history.append(float(len(getattr(policy, "_options", {}))))
        current_metrics = _episode_metric_snapshot(agent)
        episode_count = len(self.reward_history)
        for metric_name in current_metrics:
            self.metric_histories.setdefault(
                metric_name,
                [float("nan")] * (episode_count - 1),
            )
        for metric_name, history in self.metric_histories.items():
            history.append(float(current_metrics.get(metric_name, float("nan"))))
        # Episode indices are treated as 0-based for the checkpoint period.
        if self.checkpoint_every > 0 and (current_episode + 1) % self.checkpoint_every == 0:
            self.save_checkpoint()

    def _history_payload(self) -> dict[str, Any]:
        """Bundle all histories and summary stats (requires >= 1 episode)."""
        # NOTE(review): metric histories may hold NaN placeholders, which
        # json.dumps emits as the non-standard token `NaN`.
        return {
            "average_window": self.average_window,
            "episodes": len(self.reward_history),
            "reward_history": self.reward_history,
            "average_history": self.average_history,
            "epsilon_history": self.epsilon_history,
            "option_history": self.option_history,
            "metric_histories": self.metric_histories,
            "best_reward": max(self.reward_history),
            "final_reward": self.reward_history[-1],
            "final_average": self.average_history[-1],
        }

    def save_checkpoint(self) -> None:
        """Write the raw training histories without rendering plots."""
        if not self.reward_history:
            return
        self.output_dir.mkdir(parents=True, exist_ok=True)
        stem = self.output_dir / self.prefix
        payload = self._history_payload()
        payload["checkpoint"] = True
        stem.with_name(f"{self.prefix}_reward_history_checkpoint.json").write_text(
            json.dumps(payload, indent=2) + "\n"
        )

    def save(self) -> None:
        """Write SVG and PNG plots plus the raw history JSON.

        Does nothing if no episode has been logged.
        """
        if not self.reward_history:
            return

        self.output_dir.mkdir(parents=True, exist_ok=True)
        stem = self.output_dir / self.prefix

        _write_line_plot_assets(
            stem.with_name(f"{self.prefix}_reward_curve"),
            title=f"{self.prefix} reward over episodes",
            x_label="Episode",
            y_label="Reward",
            series=[
                (
                    "reward",
                    self.reward_history,
                    self.reward_color,
                    _RAW_SERIES_OPACITY,
                    _RAW_SERIES_WIDTH,
                ),
                (
                    f"avg{self.average_window}",
                    _moving_average(self.reward_history, self.average_window),
                    self.average_color,
                    _DEFAULT_SERIES_OPACITY,
                    _AVG_SERIES_WIDTH,
                ),
            ],
        )
        _write_line_plot_assets(
            stem.with_name(f"{self.prefix}_training_state"),
            title=f"{self.prefix} exploration and option count",
            x_label="Episode",
            y_label="Value",
            series=[
                ("epsilon", self.epsilon_history, self.epsilon_color),
                ("options", self.option_history, self.option_color),
            ],
        )

        for metric_name, history in sorted(self.metric_histories.items()):
            if not history:
                continue
            metric_label = metric_name.replace("_", " ")
            color = self.metric_colors.get(metric_name, "#6b7280")
            smoothed = _moving_average(history, self.average_window)
            _write_line_plot_assets(
                stem.with_name(f"{self.prefix}_{metric_name}"),
                title=f"{self.prefix} {metric_label} over episodes",
                x_label="Episode",
                y_label=metric_label,
                series=[
                    (
                        f"{metric_label} raw",
                        history,
                        color,
                        _RAW_SERIES_OPACITY,
                        _RAW_SERIES_WIDTH,
                    ),
                    (
                        f"{metric_label} avg{self.average_window}",
                        smoothed,
                        color,
                        _DEFAULT_SERIES_OPACITY,
                        _AVG_SERIES_WIDTH,
                    ),
                ],
            )

        payload = self._history_payload()
        stem.with_name(f"{self.prefix}_reward_history.json").write_text(
            json.dumps(payload, indent=2) + "\n"
        )
Persist reward, moving-average, exploration-state, and per-component metric curves for a full training run.
587 def log_episode( 588 self, 589 episode: int, 590 reward: float, 591 avg_reward: float, 592 agent: Any, 593 ) -> None: 594 """Record one episode's aggregate metrics.""" 595 current_episode = int(episode) 596 self.reward_history.append(float(reward)) 597 self.average_history.append(float(avg_reward)) 598 self.epsilon_history.append( 599 float(getattr(agent.reactive_policy, "epsilon", 0.0)) 600 ) 601 self.option_history.append( 602 float(len(getattr(agent.reactive_policy, "_options", {}))) 603 ) 604 current_metrics = _episode_metric_snapshot(agent) 605 episode_count = len(self.reward_history) 606 for metric_name in current_metrics: 607 self.metric_histories.setdefault( 608 metric_name, 609 [float("nan")] * (episode_count - 1), 610 ) 611 for metric_name, history in self.metric_histories.items(): 612 history.append(float(current_metrics.get(metric_name, float("nan")))) 613 if self.checkpoint_every > 0 and (current_episode + 1) % self.checkpoint_every == 0: 614 self.save_checkpoint()
Record one episode's aggregate metrics.
630 def save_checkpoint(self) -> None: 631 """Write the raw training histories without rendering plots.""" 632 if not self.reward_history: 633 return 634 self.output_dir.mkdir(parents=True, exist_ok=True) 635 stem = self.output_dir / self.prefix 636 payload = self._history_payload() 637 payload["checkpoint"] = True 638 stem.with_name(f"{self.prefix}_reward_history_checkpoint.json").write_text( 639 json.dumps(payload, indent=2) + "\n" 640 )
Write the raw training histories without rendering plots.
642 def save(self) -> None: 643 """Write SVG plots and the raw history JSON.""" 644 if not self.reward_history: 645 return 646 647 self.output_dir.mkdir(parents=True, exist_ok=True) 648 stem = self.output_dir / self.prefix 649 650 _write_line_plot_assets( 651 stem.with_name(f"{self.prefix}_reward_curve"), 652 title=f"{self.prefix} reward over episodes", 653 x_label="Episode", 654 y_label="Reward", 655 series=[ 656 ( 657 "reward", 658 self.reward_history, 659 self.reward_color, 660 _RAW_SERIES_OPACITY, 661 _RAW_SERIES_WIDTH, 662 ), 663 ( 664 f"avg{self.average_window}", 665 _moving_average(self.reward_history, self.average_window), 666 self.average_color, 667 _DEFAULT_SERIES_OPACITY, 668 _AVG_SERIES_WIDTH, 669 ), 670 ], 671 ) 672 _write_line_plot_assets( 673 stem.with_name(f"{self.prefix}_training_state"), 674 title=f"{self.prefix} exploration and option count", 675 x_label="Episode", 676 y_label="Value", 677 series=[ 678 ("epsilon", self.epsilon_history, self.epsilon_color), 679 ("options", self.option_history, self.option_color), 680 ], 681 ) 682 683 for metric_name, history in sorted(self.metric_histories.items()): 684 if not history: 685 continue 686 metric_label = metric_name.replace("_", " ") 687 color = self.metric_colors.get(metric_name, "#6b7280") 688 smoothed = _moving_average(history, self.average_window) 689 _write_line_plot_assets( 690 stem.with_name(f"{self.prefix}_{metric_name}"), 691 title=f"{self.prefix} {metric_label} over episodes", 692 x_label="Episode", 693 y_label=metric_label, 694 series=[ 695 ( 696 f"{metric_label} raw", 697 history, 698 color, 699 _RAW_SERIES_OPACITY, 700 _RAW_SERIES_WIDTH, 701 ), 702 ( 703 f"{metric_label} avg{self.average_window}", 704 smoothed, 705 color, 706 _DEFAULT_SERIES_OPACITY, 707 _AVG_SERIES_WIDTH, 708 ), 709 ], 710 ) 711 712 payload = self._history_payload() 713 stem.with_name(f"{self.prefix}_reward_history.json").write_text( 714 json.dumps(payload, indent=2) + "\n" 715 )
Write SVG plots and the raw history JSON.