oak

from .agent import OaKAgent
from . import fine_grained
from .interfaces import (
    ContinualLearner,
    Perception,
    ReactivePolicy,
    TransitionModel,
    ValueFunction,
    World,
)
from .types import (
    AgentStepResult,
    ComponentKind,
    CurationDecision,
    FeatureCandidate,
    FeatureSpec,
    GeneralValueFunctionSpec,
    ModelPrediction,
    OptionDescriptor,
    PlanningUpdate,
    PolicyDecision,
    SubtaskSpec,
    TimeStep,
    Transition,
    UsageRecord,
    UtilityRecord,
)

__all__ = [
    # ── Continual-learning mixin ──
    "ContinualLearner",
    # ── The four main OaK interfaces ──
    "Perception",
    "TransitionModel",
    "ValueFunction",
    "ReactivePolicy",
    # ── Agent ──
    "OaKAgent",
    # ── Environment ──
    "World",
    # ── Optional advanced assembly layer ──
    "fine_grained",
    # ── Shared types ──
    "AgentStepResult",
    "ComponentKind",
    "CurationDecision",
    "FeatureCandidate",
    "FeatureSpec",
    "GeneralValueFunctionSpec",
    "ModelPrediction",
    "OptionDescriptor",
    "PlanningUpdate",
    "PolicyDecision",
    "SubtaskSpec",
    "TimeStep",
    "Transition",
    "UsageRecord",
    "UtilityRecord",
]

Architecture Guide

`oak_core`. The default conceptual slot map: OaKAgent coordinating the four main interfaces (Perception, ValueFunction, TransitionModel, ReactivePolicy) and their main data flow.

`oak_architecture`. The fine-grained slot map: Composite modules plus the lower-level interfaces available inside each slot.

`oak_runtime_overview`. The top-level step path through the four main interfaces for the six phases: Perceive, Learn, Grow, Plan, Act, Maintain.

`oak_runtime_sequence`. The detailed composite-wired step path: OaKAgent -> Composite* -> fine_grained interface used during that step.

What You Must Implement

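OaKAgent is the canonical coordinator. It is composed of exactly four objects, one per Sutton module:

  - perception (a Perception)
  - transition_model (a TransitionModel)
  - value_function (a ValueFunction)
  - reactive_policy (a ReactivePolicy)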

You also configure scalar controls:

OaKAgent manages these runtime fields itself:

Your environment must implement the World protocol (reset, step, close) to use OaKAgent.train(). You can also drive the loop yourself by supplying TimeStep objects to OaKAgent.step(...) directly.

Two Ways to Implement

Direct approach: implement the four main interfaces directly. Each of your classes is a self-contained module. This is the simplest path and what the examples/smoke/minimal_oak.py example demonstrates.

Composite approach: use the fine-grained component interfaces from oak.fine_grained.components and wire them together using the composites from oak.fine_grained.composites. This is for projects that need to independently swap building blocks inside a module (e.g. replace the planner without touching the world model). The examples/smoke/minimal_oak_fine_grained.py example demonstrates this path with the same toy behavior as the direct example.

| Main interface | Composite class | Fine-grained building blocks |
|---|---|---|
| Perception | CompositePerception | StateBuilder, FeatureBank, FeatureConstructor, FeatureRanker, SubtaskGenerator |
| TransitionModel | CompositeTransitionModel | WorldModel, OptionModelLearner, OptionModel, Planner |
| ValueFunction | CompositeValueFunction | ValueEstimator, GeneralValueFunctionLearner, UtilityAssessor, Curator, MetaStepSizeLearner |
| ReactivePolicy | CompositeReactivePolicy | ActionSelector, Option, OptionLibrary, OptionLearner, OptionKeyboard (optional) |
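To make the delegation idea concrete, here is a minimal sketch of a composite forwarding Perception-level calls to two building blocks. ToyStateBuilder, ToyFeatureRanker, SketchCompositePerception, and the build/rank/rank_features method names are hypothetical stand-ins chosen for this example; the real CompositePerception constructor and the fine-grained method signatures in oak.fine_grained may differ.

```python
from dataclasses import dataclass, field
from typing import Optional, Sequence


class ToyStateBuilder:
    """Hypothetical StateBuilder stand-in: state = (observation, last_action)."""

    def build(self, observation: int, reward: float, last_action: Optional[int]) -> tuple:
        return (observation, last_action)  # reward is ignored in this toy builder


class ToyFeatureRanker:
    """Hypothetical FeatureRanker stand-in: highest utility first."""

    def rank(self, feature_ids: Sequence[str], utilities: dict[str, float], budget: int) -> list[str]:
        ordered = sorted(feature_ids, key=lambda f: utilities.get(f, 0.0), reverse=True)
        return ordered[:budget]


@dataclass
class SketchCompositePerception:
    """Forwards each Perception-level call to exactly one building block."""

    state_builder: ToyStateBuilder
    ranker: ToyFeatureRanker
    feature_ids: list[str] = field(default_factory=list)
    state: Optional[tuple] = None

    def update(self, observation: int, reward: float, last_action: Optional[int]) -> tuple:
        self.state = self.state_builder.build(observation, reward, last_action)
        return self.state

    def current_subjective_state(self) -> Optional[tuple]:
        return self.state

    def rank_features(self, utilities: dict[str, float], budget: int) -> list[str]:
        return self.ranker.rank(self.feature_ids, utilities, budget)


perception = SketchCompositePerception(ToyStateBuilder(), ToyFeatureRanker(), ["f0", "f1", "f2"])
perception.update(observation=3, reward=0.0, last_action=1)
print(perception.rank_features({"f1": 0.9, "f2": 0.4}, budget=2))  # ['f1', 'f2']
```

Replacing ToyFeatureRanker with another ranker leaves the state builder untouched, which is the kind of independent swap the composite layer is meant to allow.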

Diagram-to-Code Mapping

The diagrams have different jobs, but they all describe the same implementation:

Recommended reading order for the diagrams:

  1. Read oak_core to understand the default four-interface surface.
  2. Read oak_runtime_overview for the six phases of step(...).
  3. Read oak_architecture to see how the optional fine-grained layer is assembled.
  4. Read oak_runtime_sequence to trace one composite-wired execution path.

oak_runtime_overview and oak_runtime_sequence describe the same six phases. The difference is only the level of expansion: oak_runtime_overview stays at the four-interface layer, while oak_runtime_sequence shows what happens when those slots are filled by the Composite* implementations from oak.fine_grained.composites. If either diagram ever disagrees with the code, treat the code as authoritative and fix the documentation.

The diagrams are intentionally runtime-oriented. They are not exhaustive method inventories for the interfaces. For the full surface area (reset, predict, current_subjective_state, OptionKeyboard, and so on), use the API reference below. oak_architecture is the broadest inventory view; oak_runtime_overview and oak_runtime_sequence are narrower and only show what matters for one OaKAgent.step(...).

Step Walkthrough

Read the method as a pipeline. Each block below corresponds to the next block of code in OaKAgent.step(...).

1. Perceive

subjective_state = self.perception.update(...)

time_step is the input. It carries observation, reward, terminated, truncated, and optional info. perception.update(...) receives the observation, the reward, and the previous action (last_action, which may be None) and must turn them into the current subjective_state. Every later call in the step uses this subjective_state, so your Perception implementation defines what the agent actually reasons over.
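As an illustration, a Perception whose subjective state is the observation followed by a one-hot code for the previous action might look like the sketch below. It is a standalone class rather than an oak.Perception subclass so that it runs without the package, and NUM_ACTIONS and the encoding itself are assumptions made for the example.

```python
from typing import Optional, Sequence

NUM_ACTIONS = 3  # assumption: a small discrete action space


class OneHotActionPerception:
    """Subjective state = observation features + one-hot code of the previous action."""

    def __init__(self) -> None:
        self._state: tuple[float, ...] = ()

    def reset(self) -> None:
        self._state = ()

    def update(self, observation: Sequence[float], reward: float,
               last_action: Optional[int]) -> tuple[float, ...]:
        # last_action is None when no action has been taken yet: encode it as all zeros.
        action_code = [0.0] * NUM_ACTIONS
        if last_action is not None:
            action_code[last_action] = 1.0
        # The reward is not part of this particular state representation.
        self._state = tuple(observation) + tuple(action_code)
        return self._state

    def current_subjective_state(self) -> tuple[float, ...]:
        return self._state


perception = OneHotActionPerception()
print(perception.update([0.5, -1.0], reward=0.0, last_action=None))  # (0.5, -1.0, 0.0, 0.0, 0.0)
```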

2. Learn

td_errors = self.value_function.update(transition)
self.reactive_policy.update(transition, td_errors)
self.transition_model.update(transition)

Learning starts only once the agent has both a previous subjective_state and a previous action. The first call to step(...) therefore sets up memory but cannot yet build a full transition.

The Transition packages the previous and next subjective states, the action, the reward, the termination outcome, and optional info. All three modules receive it. value_function.update returns TD errors that reactive_policy.update uses for policy improvement.
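A minimal sketch of the Learn phase, assuming a tabular state-value estimate trained with TD(0). The Transition dataclass is a hypothetical stand-in with the fields listed above (the real oak.Transition field names may differ), and the "main" GVF identifier, step size, and discount are assumptions.

```python
from dataclasses import dataclass
from typing import Any, Hashable, Optional


@dataclass(frozen=True)
class Transition:  # hypothetical stand-in for oak.Transition
    subjective_state: Hashable
    action: int
    reward: float
    next_subjective_state: Hashable
    terminated: bool
    truncated: bool = False
    info: Optional[dict[str, Any]] = None


class TabularTD0:
    def __init__(self, step_size: float = 0.1, discount: float = 0.9) -> None:
        self.values: dict[Hashable, float] = {}
        self.step_size = step_size
        self.discount = discount

    def update(self, transition: Transition, *, planning: bool = False) -> dict[str, float]:
        # `planning` mirrors the interface signature; this sketch treats both cases alike.
        v = self.values.get(transition.subjective_state, 0.0)
        # Bootstrap from the next state unless the episode terminated.
        v_next = 0.0 if transition.terminated else self.values.get(transition.next_subjective_state, 0.0)
        td_error = transition.reward + self.discount * v_next - v
        self.values[transition.subjective_state] = v + self.step_size * td_error
        return {"main": td_error}  # TD errors keyed by a (hypothetical) GVF id


vf = TabularTD0()
prev_state, prev_action = None, None
td_errors: dict[str, float] = {}
# Each tuple: (observed state, reward on arrival, terminated flag, action then chosen).
for state, reward, terminated, action in [("s0", 0.0, False, 0), ("s1", 1.0, True, 1)]:
    # A transition exists only once both a previous state and a previous action are known.
    if prev_state is not None and prev_action is not None:
        td_errors = vf.update(Transition(prev_state, prev_action, reward, state, terminated))
    prev_state, prev_action = state, action
print(td_errors, vf.values)  # {'main': 1.0} {'s0': 0.1}
```

Only the second observation produces a transition, matching the rule that the first step(...) call just sets up memory.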

3. Grow

ranked_feature_ids = self.perception.discover_and_rank_features(...)
created_subtasks = self.perception.generate_subtasks(ranked_feature_ids)
self.reactive_policy.ingest_subtasks(created_subtasks)
self.reactive_policy.integrate_options()
self.transition_model.integrate_option_models()

perception proposes new features, ranks them by utility, and generates subtasks from the most useful ones. reactive_policy turns subtasks into options. transition_model integrates the latest option models so planning can reason about them. In the overview diagram this appears as top-level module calls; in the detailed sequence diagram the same phase is expanded into Composite* -> fine_grained interface calls.
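The Grow phase can be sketched with plain data structures. In this hypothetical version, each ranked feature becomes a "reach this feature" subtask (one reading of the feature-attainment subtasks in Sutton's OaK proposal), and the policy side registers one option per ingested subtask. The Subtask fields, the ID formats, and the OptionRegistry class are assumptions, not the oak.SubtaskSpec or OptionDescriptor definitions.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class Subtask:  # hypothetical stand-in for oak.SubtaskSpec
    subtask_id: str
    feature_id: str
    bonus_weight: float  # reward bonus for driving the feature high


def generate_subtasks(ranked_feature_ids: Sequence[str], known: set[str]) -> list[Subtask]:
    """Create one subtask per ranked feature that does not have one yet (mutates `known`)."""
    created = []
    for feature_id in ranked_feature_ids:
        subtask_id = f"reach:{feature_id}"
        if subtask_id not in known:
            known.add(subtask_id)
            created.append(Subtask(subtask_id, feature_id, bonus_weight=1.0))
    return created


class OptionRegistry:
    """Policy-side bookkeeping: one option per ingested subtask."""

    def __init__(self) -> None:
        self.pending: list[Subtask] = []
        self.options: dict[str, str] = {}  # option_id -> subtask_id

    def ingest_subtasks(self, subtasks: Sequence[Subtask]) -> None:
        self.pending.extend(subtasks)

    def integrate_options(self) -> None:
        # A real option learner would train first; here options are exported immediately.
        for subtask in self.pending:
            self.options[f"opt:{subtask.subtask_id}"] = subtask.subtask_id
        self.pending.clear()


known_subtasks: set[str] = set()
policy = OptionRegistry()
policy.ingest_subtasks(generate_subtasks(["f2", "f0"], known_subtasks))
policy.integrate_options()
policy.ingest_subtasks(generate_subtasks(["f2"], known_subtasks))  # already known: nothing new
policy.integrate_options()
print(sorted(policy.options))  # ['opt:reach:f0', 'opt:reach:f2']
```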

4. Plan

planning_update = self.transition_model.plan(
    subjective_state, self.value_function, self.planning_budget
)
self.reactive_policy.apply_planning_update(planning_update)

transition_model.plan(...) receives the current subjective_state, the value_function (for state evaluation during search), and a budget. It returns a PlanningUpdate. reactive_policy is informed about the planner's output before action selection.
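For intuition, here is a one-step lookahead planner over a learned deterministic tabular model. It scores at most budget actions as reward plus the discounted value of the predicted next state and returns a plain dict standing in for PlanningUpdate. The model format, the dict keys, the value_of callable (standing in for a lookup through value_function.predict), and reading budget as a cap on evaluated actions are all assumptions.

```python
from typing import Callable, Hashable


def plan_one_step(
    state: Hashable,
    model: dict[tuple[Hashable, int], tuple[float, Hashable]],
    value_of: Callable[[Hashable], float],
    budget: int,
    discount: float = 0.9,
) -> dict:
    """One-step lookahead: q(s, a) = r + discount * V(s') for up to `budget` modeled actions."""
    candidates = [a for (s, a) in model if s == state][:budget]
    action_values = {}
    for action in candidates:
        reward, next_state = model[(state, action)]
        action_values[action] = reward + discount * value_of(next_state)
    if not action_values:
        # Nothing modeled for this state yet: return an empty update.
        return {"policy_target": None, "value_target": None, "action_values": {}}
    best = max(action_values, key=action_values.get)
    return {"policy_target": best, "value_target": action_values[best], "action_values": action_values}


model = {("s0", 0): (0.0, "s1"), ("s0", 1): (1.0, "s2")}
values = {"s1": 2.0, "s2": 0.0}
update = plan_one_step("s0", model, lambda s: values.get(s, 0.0), budget=2)
print(update["policy_target"], update["value_target"])  # 0 1.8
```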

5. Act

action, active_option_id = self.reactive_policy.select_action(
    subjective_state, self.option_stop_threshold
)

The reactive policy either continues an active option or makes a fresh decision. The output is always a primitive action, because that is what the caller receives in AgentStepResult.
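One way to realize "continue an active option or decide afresh" is sketched below. It assumes each option exposes a policy and a termination probability, and it reads option_stop_threshold as the termination probability at or above which the active option is abandoned. That reading, the chooser callable, and the ToyOption type are hypothetical; the real semantics are whatever your ReactivePolicy defines.

```python
from dataclasses import dataclass
from typing import Callable, Hashable, Optional


@dataclass
class ToyOption:
    """Hypothetical option: an internal policy plus a termination probability."""
    policy: Callable[[Hashable], int]
    termination: Callable[[Hashable], float]


class OptionAwarePolicy:
    def __init__(
        self,
        options: dict[str, ToyOption],
        chooser: Callable[[Hashable], Optional[str]],
        primitive: Callable[[Hashable], int],
    ) -> None:
        self.options = options
        self.chooser = chooser      # fresh decision: which option to start, or None
        self.primitive = primitive  # primitive policy used when no option is started
        self.active_option_id: Optional[str] = None

    def select_action(self, state: Hashable, option_stop_threshold: float) -> tuple[int, Optional[str]]:
        if self.active_option_id is not None:
            option = self.options[self.active_option_id]
            if option.termination(state) < option_stop_threshold:
                # Continue the option, but still emit a primitive action.
                return option.policy(state), self.active_option_id
            self.active_option_id = None  # option stops; fall through to a fresh decision
        chosen = self.chooser(state)
        if chosen is not None:
            self.active_option_id = chosen
            return self.options[chosen].policy(state), chosen
        return self.primitive(state), None

    def clear_active_option(self) -> None:
        self.active_option_id = None


go_right = ToyOption(policy=lambda s: 1, termination=lambda s: 1.0 if s == "goal" else 0.0)
policy = OptionAwarePolicy(
    {"go_right": go_right},
    chooser=lambda s: None if s == "goal" else "go_right",
    primitive=lambda s: 0,
)
print([policy.select_action(s, 0.5) for s in ["a", "b", "goal"]])
# [(1, 'go_right'), (1, 'go_right'), (0, None)]
```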

6. Maintain

self.value_function.observe_usage(usage_records)
curation_decision = self.value_function.curate()
self._apply_curation(curation_decision)

Usage records for ranked features and the active option are sent to the value function for utility tracking. The value function then decides what to prune. _apply_curation(...) dispatches the decision to the relevant modules: perception.remove_features(...), reactive_policy.remove_options(...), reactive_policy.remove_subtasks(...), transition_model.remove_option_models(...), value_function.remove(...).
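The dispatch in _apply_curation(...) is essentially a fan-out. This sketch uses a stand-in decision with one ID list per kind of structure; the Decision field names are hypothetical (check oak.CurationDecision for the real ones). Option IDs go to both the policy and the transition model because remove_option_models takes OptionId values. unittest.mock.MagicMock records each call so the routing can be inspected.

```python
from dataclasses import dataclass, field
from unittest.mock import MagicMock


@dataclass
class Decision:  # hypothetical stand-in for oak.CurationDecision
    feature_ids: list[str] = field(default_factory=list)
    option_ids: list[str] = field(default_factory=list)
    subtask_ids: list[str] = field(default_factory=list)
    general_value_function_ids: list[str] = field(default_factory=list)


def apply_curation(decision, perception, reactive_policy, transition_model, value_function) -> None:
    """Fan one curation decision out to the modules that own each structure."""
    perception.remove_features(decision.feature_ids)
    reactive_policy.remove_options(decision.option_ids)
    reactive_policy.remove_subtasks(decision.subtask_ids)
    # Option models are keyed by option ID, so they are dropped together with their options.
    transition_model.remove_option_models(decision.option_ids)
    value_function.remove(decision.general_value_function_ids)


# MagicMock records every call, which makes the fan-out easy to inspect.
perception, policy, model, values = MagicMock(), MagicMock(), MagicMock(), MagicMock()
apply_curation(Decision(feature_ids=["f3"], option_ids=["o1"]), perception, policy, model, values)
print(model.remove_option_models.call_args.args)  # (['o1'],)
```

This sketch calls every removal method even when a list is empty; skipping empty lists is an equally valid design.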

Training Loop

OaKAgent.train() provides a standard episode loop so implementations don't need to rewrite the reset/step/terminate boilerplate:

agent = build_my_agent()
world = MyWorld()          # must implement the World protocol

def log_episode(episode, reward, avg_reward, agent):
    if episode % 10 == 0:
        print(f"episode={episode} reward={reward:.1f} avg={avg_reward:.1f}")

rewards = agent.train(
    world,
    num_episodes=500,
    solved_threshold=475.0,  # optional early stopping
    episode_logger=log_episode,
)
world.close()

The World protocol requires three methods: reset, step, and close.

If you need custom per-episode logging, pass episode_logger(...). If you need a fully custom training loop (non-episodic environments, multi-agent setups, custom control flow), call agent.step(time_step) directly instead.
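For a loop you drive yourself, the world only needs to produce TimeStep-like records. The sketch below pairs a toy corridor world with a hand-written episode loop. The TimeStep dataclass mirrors the fields described in the Perceive step, while the reset()/step(action) return convention, the Corridor world, and the choose_action callback (which stands where a call to agent.step(time_step) and reading its action would go) are assumptions made to keep the example self-contained.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass(frozen=True)
class TimeStep:  # stand-in with the fields described for oak.TimeStep
    observation: int
    reward: float
    terminated: bool
    truncated: bool
    info: dict[str, Any] = field(default_factory=dict)


class Corridor:
    """Toy world: walk right from cell 0 to cell `length`; every move costs -1."""

    def __init__(self, length: int = 3, max_steps: int = 10) -> None:
        self.length, self.max_steps = length, max_steps
        self.position, self.steps = 0, 0

    def reset(self) -> TimeStep:
        self.position, self.steps = 0, 0
        return TimeStep(self.position, 0.0, False, False)

    def step(self, action: int) -> TimeStep:
        self.steps += 1
        self.position = max(0, self.position + (1 if action == 1 else -1))
        terminated = self.position >= self.length
        truncated = not terminated and self.steps >= self.max_steps
        return TimeStep(self.position, -1.0, terminated, truncated)

    def close(self) -> None:
        pass  # nothing to release in this toy world


def run_episode(world: Corridor, choose_action: Callable[[TimeStep], int]) -> float:
    """Hand-written episode loop returning the undiscounted episode return."""
    time_step = world.reset()
    episode_return = 0.0
    while True:
        # Show every time step to the agent, including the final one, so the
        # terminal transition can still be learned from.
        action = choose_action(time_step)
        if time_step.terminated or time_step.truncated:
            return episode_return
        time_step = world.step(action)
        episode_return += time_step.reward


world = Corridor()
print(run_episode(world, lambda ts: 1))  # -3.0
world.close()
```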

Implementation Order

If your goal is to get a working agent quickly, implement in this order:

  1. Make Perception produce a useful subjective_state from TimeStep. Have discover_and_rank_features return a fixed list and generate_subtasks return an empty sequence.
  2. Make ReactivePolicy return valid actions from select_action. Have the other methods be no-ops.
  3. Make ValueFunction accept update calls (returning a TD-error mapping, possibly empty) and return values from predict. Have curate return an empty CurationDecision.
  4. Make TransitionModel accept update calls and return a valid PlanningUpdate from plan, even if trivial.

That is enough to satisfy the exact call sequence of OaKAgent.step(...). After that, you can improve learning quality without changing the basic wiring.
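The four steps above can be sketched as trivial stubs driven through a simplified six-phase step. StubAgent is not OaKAgent: its attribute names follow the walkthrough, but the tuple transition, the empty usage records, the hypothetical budget constants, and the discarded curation result are simplifications, and the stubs do not subclass the oak interfaces so the snippet runs on its own.

```python
from typing import Any, Optional

FEATURE_BUDGET, PLANNING_BUDGET, OPTION_STOP_THRESHOLD = 4, 8, 0.5  # hypothetical values


class StubPerception:
    def update(self, observation: Any, reward: float, last_action: Optional[int]) -> Any:
        return observation  # the raw observation is the subjective state

    def discover_and_rank_features(self, state: Any, utility_scores: list, feature_budget: int) -> list[str]:
        return ["f0"][:feature_budget]  # fixed list, no discovery yet

    def generate_subtasks(self, ranked_feature_ids: list[str]) -> list:
        return []  # no subtasks yet


class StubPolicy:
    def update(self, transition: tuple, td_errors: dict[str, float]) -> None:
        pass

    def ingest_subtasks(self, subtasks: list) -> None:
        pass

    def integrate_options(self) -> None:
        pass

    def apply_planning_update(self, update: dict) -> None:
        pass

    def select_action(self, state: Any, option_stop_threshold: float) -> tuple[int, Optional[str]]:
        return 0, None  # always a valid primitive action, never an option


class StubValueFunction:
    def update(self, transition: tuple, *, planning: bool = False) -> dict[str, float]:
        return {"main": 0.0}

    def predict(self, state: Any) -> dict[str, float]:
        return {"main": 0.0}

    def utility_scores(self) -> list:
        return []

    def observe_usage(self, usage_records: list) -> None:
        pass

    def curate(self) -> None:
        return None  # stands in for an empty CurationDecision


class StubTransitionModel:
    def update(self, transition: tuple) -> None:
        pass

    def integrate_option_models(self) -> None:
        pass

    def plan(self, state: Any, value_function: StubValueFunction, budget: int) -> dict:
        return {"value_target": value_function.predict(state)["main"]}  # trivial plan


class StubAgent:
    """Simplified six-phase step; not the real OaKAgent."""

    def __init__(self) -> None:
        self.perception, self.value_function = StubPerception(), StubValueFunction()
        self.transition_model, self.reactive_policy = StubTransitionModel(), StubPolicy()
        self.prev_state: Any = None
        self.prev_action: Optional[int] = None

    def step(self, observation: Any, reward: float, terminated: bool) -> int:
        state = self.perception.update(observation, reward, self.prev_action)  # 1. Perceive
        if self.prev_state is not None and self.prev_action is not None:       # 2. Learn
            transition = (self.prev_state, self.prev_action, reward, state, terminated)
            td_errors = self.value_function.update(transition)
            self.reactive_policy.update(transition, td_errors)
            self.transition_model.update(transition)
        ranked = self.perception.discover_and_rank_features(                    # 3. Grow
            state, self.value_function.utility_scores(), FEATURE_BUDGET)
        self.reactive_policy.ingest_subtasks(self.perception.generate_subtasks(ranked))
        self.reactive_policy.integrate_options()
        self.transition_model.integrate_option_models()
        plan = self.transition_model.plan(state, self.value_function, PLANNING_BUDGET)  # 4. Plan
        self.reactive_policy.apply_planning_update(plan)
        action, _ = self.reactive_policy.select_action(state, OPTION_STOP_THRESHOLD)  # 5. Act
        self.value_function.observe_usage([])                                   # 6. Maintain
        self.value_function.curate()  # a real agent would hand this to _apply_curation
        self.prev_state, self.prev_action = state, action
        return action


agent = StubAgent()
print([agent.step(obs, 0.0, False) for obs in range(3)])  # [0, 0, 0]
```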

Repository Examples

The concrete implementations live outside oak on purpose. That shows the intended usage pattern: the package provides the canonical OaKAgent coordinator and interfaces, while downstream code provides the implementations. The generated docs include the repository-level examples package alongside the core oak API.

To run all repository smoke tests, including the minimal example:

pixi run tests

To inspect the smallest runnable example directly:

from examples.smoke.minimal_oak import build_minimal_agent, run_minimal_episode

agent = build_minimal_agent()
trace = run_minimal_episode(horizon=5)

run_minimal_episode(...) returns a compact trace with the subjective_state, primitive action, active_option_id, created_subtasks, and planner output at each step.

Design Constraints

Keep these constraints in mind when you replace the minimal pieces with real ones:


API Documentation

class ContinualLearner:

Mixin for modules whose weights are adapted by meta-learned step sizes.

In Sutton's OaK architecture, every learned weight has a dedicated step-size parameter adapted via online cross-validation (e.g. IDBD, Sutton 1992; Adam-IDBD, Degris et al. 2024).

The agent loop calls update_meta() on all four modules after each learning step, passing the same error-signals dict. Each module internally decides which signals are relevant and routes them to its per-weight step-size adaptation.

The default implementation is a no-op so that modules without meta-learning still work unchanged.
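To show what a non-trivial update_meta could do, here is a sketch of IDBD (Sutton 1992) for a linear predictor. It assumes the module caches the feature vector from its last predict call and reads the error from the "main_td_error" key used in the docstring example. For simplicity the whole IDBD step (step-size, weight, and trace updates) runs inside update_meta; this is the original supervised IDBD rule, not a TD-specific variant, and the caching design is an assumption.

```python
import math
from typing import Mapping, Sequence


class IDBDLinear:
    """Linear predictor whose per-weight step sizes are adapted by IDBD."""

    def __init__(self, num_features: int, initial_step_size: float = 0.05,
                 meta_step_size: float = 0.01) -> None:
        self.weights = [0.0] * num_features
        self.log_step_sizes = [math.log(initial_step_size)] * num_features  # beta_i
        self.traces = [0.0] * num_features  # h_i
        self.meta_step_size = meta_step_size  # theta
        self._last_features: list[float] = [0.0] * num_features

    def predict(self, features: Sequence[float]) -> float:
        self._last_features = list(features)  # cached for the next update_meta call
        return sum(w * x for w, x in zip(self.weights, features))

    def update_meta(self, error_signals: Mapping[str, float]) -> None:
        delta = error_signals.get("main_td_error")
        if delta is None:
            return  # ignore signals this module does not use
        for i, x in enumerate(self._last_features):
            # beta_i += theta * delta * x_i * h_i, then alpha_i = exp(beta_i)
            self.log_step_sizes[i] += self.meta_step_size * delta * x * self.traces[i]
            alpha = math.exp(self.log_step_sizes[i])
            self.weights[i] += alpha * delta * x
            # h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * delta * x_i
            self.traces[i] = self.traces[i] * max(0.0, 1.0 - alpha * x * x) + alpha * delta * x


learner = IDBDLinear(num_features=2)
learner.predict([1.0, 0.0])
learner.update_meta({"main_td_error": 1.0, "reward": 1.0})  # "reward" is ignored
```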

def update_meta(self, error_signals: 'Mapping[str, float]') -> 'None':

Adapt internal per-weight step sizes given error signals.

Parameters

error_signals: Named scalar error signals from the current learning step, e.g. {"main_td_error": 0.05, "reward": 1.0}. Implementations pick the signals they need and ignore the rest.

class Perception(oak.ContinualLearner, abc.ABC, typing.Generic[~ObservationT, ~ActionT, ~SubjectiveStateT]):

Sutton's Perception: observations → subjective state + feature management.

Turns raw observations into the agent's subjective state, the internal representation that every other module sees. Also discovers, ranks, and manages features (learned representational structures that grow over the agent's lifetime) and generates subtasks from the most useful ones.

The finer-grained layer splits this into StateBuilder, FeatureBank, FeatureConstructor, FeatureRanker, and SubtaskGenerator (see oak.fine_grained.components).

@abstractmethod
def reset(self) -> 'None':

Reset all perception state for a new episode.

@abstractmethod
def update(self, observation: 'ObservationT', reward: 'float', last_action: 'ActionT | None') -> 'SubjectiveStateT':

Process a new observation and return the updated subjective state.

@abstractmethod
def current_subjective_state(self) -> 'SubjectiveStateT':

Return the most recently computed subjective state.

@abstractmethod
def discover_and_rank_features(self, subjective_state: 'SubjectiveStateT', utility_scores: 'Sequence[UtilityRecord]', feature_budget: 'int') -> 'Sequence[FeatureId]':

Propose new features, integrate them, and return the top-ranked IDs.

A typical implementation:

  1. Proposes candidate features from the current subjective state.
  2. Adds accepted candidates to its internal feature store.
  3. Ranks all features using the provided utility scores.
  4. Returns the top feature IDs (up to feature_budget).

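The four steps above map onto a short function. In this sketch the candidates are "coordinate i is positive" features of a vector-valued subjective state, the feature store is a dict, and utility scores arrive as (feature_id, utility) pairs; all three representations are assumptions standing in for FeatureCandidate, FeatureSpec, and UtilityRecord.

```python
from typing import Sequence


def discover_and_rank_features(
    state: Sequence[float],
    feature_store: dict[str, int],
    utility_scores: Sequence[tuple[str, float]],
    feature_budget: int,
) -> list[str]:
    # 1. Propose: one candidate feature per positive coordinate of the state.
    candidates = {f"pos:{i}": i for i, x in enumerate(state) if x > 0}
    # 2. Integrate: accept candidates that are not already stored.
    for feature_id, index in candidates.items():
        feature_store.setdefault(feature_id, index)
    # 3. Rank every stored feature by utility (unscored features count as 0).
    utility = dict(utility_scores)
    ranked = sorted(feature_store, key=lambda f: utility.get(f, 0.0), reverse=True)
    # 4. Return at most feature_budget IDs.
    return ranked[:feature_budget]


store: dict[str, int] = {"pos:0": 0}
print(discover_and_rank_features([0.0, 2.0, 1.0], store, [("pos:2", 0.7), ("pos:0", 0.1)], 2))
# ['pos:2', 'pos:0']
```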
@abstractmethod
def generate_subtasks(self, ranked_feature_ids: 'Sequence[FeatureId]') -> 'Sequence[SubtaskSpec]':

Turn ranked feature IDs into subtask specifications.

@abstractmethod
def list_features(self) -> 'Sequence[FeatureSpec]':

Return all currently tracked features.

@abstractmethod
def remove_features(self, feature_ids: 'Sequence[FeatureId]') -> 'None':

Remove features by ID (called during curation).

class TransitionModel(oak.ContinualLearner, abc.ABC, typing.Generic[~SubjectiveStateT, ~ActionT, ~InfoT]):

Sutton's Transition Model: world dynamics + option models + planning.

Learns from observed transitions, maintains option models that predict the effect of temporal abstractions, and runs bounded planning using the world model and the value function to produce improvement signals for the reactive policy.

The finer-grained layer splits this into WorldModel, OptionModelLearner, individual OptionModel objects, and a Planner (see oak.fine_grained.components).

@abstractmethod
def update(self, transition: 'Transition[ActionT, SubjectiveStateT, InfoT]') -> 'None':

Learn from an observed transition.

This should update both the world model and any option-model learners.
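A tabular world model is the simplest thing update(...) can maintain. The sketch keeps, for each (state, action) pair, a running mean of reward and counts of observed next states. The unpacked transition arguments and the most_likely helper are assumptions, and option-model learning is omitted.

```python
from collections import Counter, defaultdict
from typing import Hashable


class TabularWorldModel:
    def __init__(self) -> None:
        self.visits: Counter = Counter()
        self.mean_reward: dict[tuple[Hashable, int], float] = defaultdict(float)
        self.next_counts: dict[tuple[Hashable, int], Counter] = defaultdict(Counter)

    def update(self, state: Hashable, action: int, reward: float, next_state: Hashable) -> None:
        key = (state, action)
        self.visits[key] += 1
        # Incremental mean: m <- m + (r - m) / n
        self.mean_reward[key] += (reward - self.mean_reward[key]) / self.visits[key]
        self.next_counts[key][next_state] += 1

    def most_likely(self, state: Hashable, action: int) -> tuple[float, Hashable]:
        """Expected reward and most frequent next state (assumes the pair was observed)."""
        key = (state, action)
        next_state, _ = self.next_counts[key].most_common(1)[0]
        return self.mean_reward[key], next_state


model = TabularWorldModel()
for reward, next_state in [(1.0, "s1"), (0.0, "s1"), (0.0, "s2")]:
    model.update("s0", 1, reward, next_state)
print(model.most_likely("s0", 1))  # mean reward 1/3, most likely next state 's1'
```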

@abstractmethod
def integrate_option_models(self) -> 'None':

Export learned option models and integrate them into the world model.

Called after option learning so that planning reasons over fresh models.

@abstractmethod
def plan(self, subjective_state: 'SubjectiveStateT', value_function: 'ValueFunction[SubjectiveStateT, ActionT, InfoT]', budget: 'int') -> 'PlanningUpdate[ActionT]':

Run bounded planning and return improvement signals.

The planner uses the internal world model together with the supplied value_function (for state evaluation) to produce value targets, policy targets, or search statistics.

@abstractmethod
def remove_option_models(self, option_ids: 'Sequence[OptionId]') -> 'None':

Remove option models by ID (called during curation).

class ValueFunction(oak.ContinualLearner, abc.ABC, typing.Generic[~SubjectiveStateT, ~ActionT, ~InfoT]):

Sutton's Value Function: value learning + utility assessment + curation.

Learns predictive value signals (TD errors, GVF predictions) from observed transitions and predicts cumulative signals for any given subjective state. Also assesses the utility of the agent's learned structures (features, options, models) and produces concrete keep/drop curation decisions.

The finer-grained layer splits this into ValueEstimator, GeneralValueFunctionLearner, UtilityAssessor, Curator, and MetaStepSizeLearner (see oak.fine_grained.components).

@abstractmethod
def update(self, transition: 'Transition[ActionT, SubjectiveStateT, InfoT]', *, planning: 'bool' = False) -> 'Mapping[GeneralValueFunctionId, float]':

Learn from a transition and return TD-error signals.

@abstractmethod
def predict(self, subjective_state: 'SubjectiveStateT') -> 'Mapping[GeneralValueFunctionId, float]':

Predict values for the given subjective state.

@abstractmethod
def observe_usage(self, usage_records: 'Sequence[UsageRecord]') -> 'None':

Record usage evidence for utility assessment.

@abstractmethod
def utility_scores(self) -> 'Sequence[UtilityRecord]':

Return current utility estimates for all tracked structures.

@abstractmethod
def curate(self) -> 'CurationDecision':

Decide which learned structures to drop.
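Utility assessment and curation can be as simple as an exponential moving average of usage plus a threshold. The sketch decays every tracked score on each observe_usage call and drops structures whose score has fallen below min_utility once they are at least min_age updates old. The (structure_id, used) record format, the parameter values, and returning a plain list of IDs instead of a CurationDecision are assumptions.

```python
from typing import Sequence


class EMAUtility:
    def __init__(self, decay: float = 0.9, min_utility: float = 0.05, min_age: int = 5) -> None:
        self.decay, self.min_utility, self.min_age = decay, min_utility, min_age
        self.scores: dict[str, float] = {}
        self.ages: dict[str, int] = {}

    def observe_usage(self, usage_records: Sequence[tuple[str, bool]]) -> None:
        used = {sid for sid, was_used in usage_records if was_used}
        for sid, _ in usage_records:
            self.scores.setdefault(sid, 1.0)  # new structures start optimistic
            self.ages.setdefault(sid, 0)
        # Decay every tracked score toward 1 if used this round, toward 0 otherwise.
        for sid in self.scores:
            target = 1.0 if sid in used else 0.0
            self.scores[sid] = self.decay * self.scores[sid] + (1 - self.decay) * target
            self.ages[sid] += 1

    def utility_scores(self) -> list[tuple[str, float]]:
        return sorted(self.scores.items())

    def curate(self) -> list[str]:
        drop = [sid for sid, score in self.scores.items()
                if self.ages[sid] >= self.min_age and score < self.min_utility]
        for sid in drop:
            del self.scores[sid], self.ages[sid]
        return drop


utility = EMAUtility(decay=0.5, min_utility=0.1, min_age=3)
for _ in range(5):
    utility.observe_usage([("feature:a", True), ("feature:b", False)])
print(utility.curate())  # ['feature:b']
```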

@abstractmethod
def remove(self, general_value_function_ids: 'Sequence[GeneralValueFunctionId]') -> 'None':

Remove value functions by ID (called during curation).

class ReactivePolicy(oak.ContinualLearner, abc.ABC, typing.Generic[~SubjectiveStateT, ~ActionT, ~InfoT]):

Sutton's Reactive Policy: action selection + option management.

Selects actions, either primitive actions or temporally extended ones (options), based on the current subjective state. Manages the option library and the option-learning pipeline, and integrates planning updates into decision-making.

The finer-grained layer splits this into ActionSelector, OptionLibrary, and OptionLearner (see oak.fine_grained.components).
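As a rough illustration of the option-management half of this contract, the toy class below tracks subtasks and options in plain dictionaries. It is a hypothetical, self-contained sketch: `ToyOptionLibrary`, its `pending` and `library` fields, the string IDs, and the option-naming rule are assumptions, and it does not subclass `oak.ReactivePolicy`. It only mirrors the order in which `OaKAgent` calls these methods: `ingest_subtasks` during Grow, `integrate_options` right after, and `remove_options` / `remove_subtasks` during curation.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOptionLibrary:
    """Hypothetical stand-in for ReactivePolicy's option-management methods."""

    # Subtasks received from Perception whose options are still being learned.
    pending: dict[str, str] = field(default_factory=dict)
    # Integrated options: option ID -> ID of the subtask it was learned for.
    library: dict[str, str] = field(default_factory=dict)

    def ingest_subtasks(self, subtask_ids: list[str]) -> None:
        # Queue one (toy) option per new subtask.
        for subtask_id in subtask_ids:
            self.pending[subtask_id] = f"option-for-{subtask_id}"

    def integrate_options(self) -> None:
        # Move every queued option into the library, then empty the queue.
        for subtask_id, option_id in self.pending.items():
            self.library[option_id] = subtask_id
        self.pending.clear()

    def remove_options(self, option_ids: list[str]) -> None:
        # Unknown IDs are ignored rather than raising.
        for option_id in option_ids:
            self.library.pop(option_id, None)

    def remove_subtasks(self, subtask_ids: list[str]) -> None:
        # In this toy, dropping a subtask only cancels options not yet integrated.
        for subtask_id in subtask_ids:
            self.pending.pop(subtask_id, None)


policy = ToyOptionLibrary()
policy.ingest_subtasks(["reach-door", "pick-key"])
policy.integrate_options()
policy.remove_options(["option-for-pick-key"])
policy.ingest_subtasks(["open-door"])
policy.remove_subtasks(["open-door"])
```

Keeping "pending" and "integrated" options apart reflects the split between `ingest_subtasks` (feed the learner) and `integrate_options` (export to the library) in the interface.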

@abstractmethod
def update(self, transition: 'Transition[ActionT, SubjectiveStateT, InfoT]', td_errors: 'Mapping[GeneralValueFunctionId, float]') -> 'None':

Update the policy and option learners from an observed transition.
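`update` receives the TD errors that `ValueFunction.update` has just produced for the same transition. One common way a policy can use such a signal is an actor-critic preference update, in which the critic's TD error scales a softmax policy-gradient step. The sketch below is an assumption-laden illustration, not the `oak` implementation: it uses a tabular state, a fixed action list, a single TD error stored under a hypothetical key `"main"`, a step size `alpha`, and passes `(state, action)` instead of a `Transition`.

```python
import math
from collections import defaultdict


class SoftmaxActor:
    """Hypothetical tabular actor driven by a critic's TD error."""

    def __init__(self, actions, alpha=0.1):
        self.actions = list(actions)
        self.alpha = alpha
        self.prefs = defaultdict(float)  # action preferences h(s, a), start at 0

    def probs(self, state):
        # Softmax over preferences, shifted by the max for numerical stability.
        prefs = [self.prefs[(state, a)] for a in self.actions]
        top = max(prefs)
        exps = [math.exp(h - top) for h in prefs]
        total = sum(exps)
        return {a: e / total for a, e in zip(self.actions, exps)}

    def update(self, state, action, td_errors):
        # Policy-gradient step: h(s, b) += alpha * delta * (1[b == action] - pi(b | s)).
        delta = td_errors["main"]
        pi = self.probs(state)  # evaluate pi once, before changing any preference
        for b in self.actions:
            indicator = 1.0 if b == action else 0.0
            self.prefs[(state, b)] += self.alpha * delta * (indicator - pi[b])


actor = SoftmaxActor(actions=["left", "right"], alpha=0.5)
actor.update("s0", "right", {"main": 1.0})  # positive TD error: reinforce "right"
p = actor.probs("s0")
```

A positive TD error ("better than expected") raises the preference of the action just taken and lowers the others in proportion to their current probability; a negative one does the reverse.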

@abstractmethod
def apply_planning_update(self, update: 'PlanningUpdate[ActionT]') -> 'None':

Integrate planning improvement signals into the policy.

@abstractmethod
def ingest_subtasks(self, subtasks: 'Sequence[SubtaskSpec]') -> 'None':

Feed newly created subtasks into the option learner.

@abstractmethod
def integrate_options(self) -> 'None':

Export learned options into the option library.

@abstractmethod
def select_action(self, subjective_state: 'SubjectiveStateT', option_stop_threshold: 'float') -> 'tuple[ActionT, OptionId | None]':

Choose a primitive action, possibly by continuing an active option.

Returns a (primitive_action, active_option_id) pair. When no option is active, active_option_id is None.
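The docstring leaves the exact role of `option_stop_threshold` open. A minimal sketch, assuming each option carries a termination probability beta(s) and that the active option is abandoned once beta(s) reaches the threshold, could look like the following. `Option`, `ThresholdOptionSelector`, `start_option`, and `default_action` are hypothetical names introduced for illustration only.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Option:
    """Hypothetical option: an intra-option policy and a termination function."""

    policy: Callable[[int], str]
    stop_prob: Callable[[int], float]  # beta(s): probability of stopping in state s


class ThresholdOptionSelector:
    """Sketch of select_action, assuming the option stops once beta(s) >= threshold."""

    def __init__(self, options, start_option, default_action):
        self.options = options              # option ID -> Option
        self.start_option = start_option    # state -> option ID to start, or None
        self.default_action = default_action
        self.active_option_id: Optional[str] = None

    def select_action(self, state, option_stop_threshold):
        # Terminate the running option if its stop probability is high enough.
        if self.active_option_id is not None:
            beta = self.options[self.active_option_id].stop_prob(state)
            if beta >= option_stop_threshold:
                self.active_option_id = None
        # With no option running, possibly start a new one.
        if self.active_option_id is None:
            self.active_option_id = self.start_option(state)
        if self.active_option_id is None:
            return self.default_action, None  # plain primitive action
        # Otherwise the primitive action comes from the active option's policy.
        return self.options[self.active_option_id].policy(state), self.active_option_id

    def clear_active_option(self):
        self.active_option_id = None


go_right = Option(policy=lambda s: "right", stop_prob=lambda s: 1.0 if s >= 3 else 0.0)
selector = ThresholdOptionSelector(
    {"go-right": go_right},
    start_option=lambda s: "go-right" if s < 3 else None,
    default_action="stay",
)
trace = [selector.select_action(s, option_stop_threshold=0.5) for s in range(5)]
```

In this toy corridor the option runs for states 0 to 2 and terminates at state 3, after which `select_action` returns `(primitive_action, None)`, matching the pair shape described in the docstring.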

@abstractmethod
def clear_active_option(self) -> 'None':

Clear the currently executing option (e.g. at episode boundaries).

@abstractmethod
def remove_options(self, option_ids: 'Sequence[OptionId]') -> 'None':

Remove options by ID (called during curation).

@abstractmethod
def remove_subtasks(self, subtask_ids: 'Sequence[SubtaskId]') -> 'None':

Remove subtasks by ID (called during curation).

@dataclass
class OaKAgent(typing.Generic[~ObservationT, ~ActionT, ~SubjectiveStateT, ~InfoT]):
@dataclass
class OaKAgent(Generic[ObservationT, ActionT, SubjectiveStateT, InfoT]):
    """Coordinates one full OaK step across the four modules.

    The agent is a wiring object: you provide concrete implementations of
    `Perception`, `TransitionModel`, `ValueFunction`, and
    `ReactivePolicy`, and `OaKAgent` ensures they are called in a
    consistent order.
    """

    perception: Perception[ObservationT, ActionT, SubjectiveStateT]
    transition_model: TransitionModel[SubjectiveStateT, ActionT, InfoT]
    value_function: ValueFunction[SubjectiveStateT, ActionT, InfoT]
    reactive_policy: ReactivePolicy[SubjectiveStateT, ActionT, InfoT]

    planning_budget: int = 4
    feature_budget: int = 4
    option_stop_threshold: float = 0.5

    last_action: ActionT | None = None
    last_subjective_state: SubjectiveStateT | None = None
    last_active_option_id: OptionId | None = None

    def __init__(
        self,
        perception: Perception[ObservationT, ActionT, SubjectiveStateT],
        transition_model: TransitionModel[SubjectiveStateT, ActionT, InfoT],
        value_function: ValueFunction[SubjectiveStateT, ActionT, InfoT],
        reactive_policy: ReactivePolicy[SubjectiveStateT, ActionT, InfoT],
        planning_budget: int = 4,
        feature_budget: int = 4,
        option_stop_threshold: float = 0.5,
    ):
        self.perception = perception
        self.transition_model = transition_model
        self.value_function = value_function
        self.reactive_policy = reactive_policy
        self.planning_budget = planning_budget
        self.feature_budget = feature_budget
        self.option_stop_threshold = option_stop_threshold
        self.last_action = None
        self.last_subjective_state = None
        self.last_active_option_id = None
        # @dataclass keeps this hand-written __init__, and only a generated
        # __init__ calls __post_init__ automatically, so invoke it explicitly.
        self.__post_init__()

    def __post_init__(self):
        """Validate the budget and threshold settings."""
        if self.planning_budget < 1:
            raise ValueError("Planning budget must be at least 1.")
        if self.feature_budget < 1:
            raise ValueError("Feature budget must be at least 1.")
        if self.option_stop_threshold < 0 or self.option_stop_threshold > 1:
            raise ValueError("Option stop threshold must be in [0, 1].")

    def reset(self) -> None:
        """Clear transient execution memory."""
        self.perception.reset()
        self.reactive_policy.clear_active_option()
        self.last_action = None
        self.last_subjective_state = None
        self.last_active_option_id = None

    def step(
        self, time_step: TimeStep[ObservationT, InfoT]
    ) -> AgentStepResult[ActionT, SubjectiveStateT]:
        """Run one temporally uniform agent step.

        The step follows six phases: perceive, learn, grow, plan, act, maintain.
        """

        # ================== 1. Perceive ================= #
        subjective_state = self.perception.update(
            observation=time_step.observation,
            reward=time_step.reward,
            last_action=self.last_action,
        )

        created_subtasks: Sequence[SubtaskSpec] = ()
        ranked_feature_ids: Sequence[FeatureId] = ()
        planning_update: PlanningUpdate[ActionT] | None = None
        curation_decision: CurationDecision | None = None

        # ================== 2. Learn ================== #
        if self.last_subjective_state is not None and self.last_action is not None:
            transition = Transition(
                subjective_state=self.last_subjective_state,
                action=self.last_action,
                reward=time_step.reward,
                next_subjective_state=subjective_state,
                terminated=time_step.terminated or time_step.truncated,
                option_id=self.last_active_option_id,
                info=time_step.info,
            )
            td_errors = self.value_function.update(transition)
            self.reactive_policy.update(transition, td_errors)
            self.transition_model.update(transition)

            # Meta step-size adaptation (Sutton's IDBD / online cross-validation)
            meta_signals = dict(td_errors)
            meta_signals["reward"] = transition.reward
            self.perception.update_meta(meta_signals)
            self.value_function.update_meta(meta_signals)
            self.reactive_policy.update_meta(meta_signals)
            self.transition_model.update_meta(meta_signals)

        # ================== 3. Grow ================== #
        ranked_feature_ids = self.perception.discover_and_rank_features(
            subjective_state,
            self.value_function.utility_scores(),
            self.feature_budget,
        )
        if ranked_feature_ids:
            created_subtasks = self.perception.generate_subtasks(ranked_feature_ids)
            if created_subtasks:
                self.reactive_policy.ingest_subtasks(created_subtasks)

        self.reactive_policy.integrate_options()
        self.transition_model.integrate_option_models()

        # ================== 4. Plan ================== #
        planning_update = self.transition_model.plan(
            subjective_state, self.value_function, self.planning_budget
        )
        self.reactive_policy.apply_planning_update(planning_update)

        # ================== 5. Act ================== #
        action, active_option_id = self.reactive_policy.select_action(
            subjective_state, self.option_stop_threshold
        )

        # ================== 6. Maintain ================== #
        usage_records = self._build_usage_records(ranked_feature_ids, active_option_id)
        if usage_records:
            self.value_function.observe_usage(usage_records)

        curation_decision = self.value_function.curate()
        self._apply_curation(curation_decision)

        # ================== Update Memory ================== #
        self.last_subjective_state = subjective_state
        self.last_action = action
        self.last_active_option_id = active_option_id

        if time_step.terminated or time_step.truncated:
            self.reactive_policy.clear_active_option()

        return AgentStepResult(
            action=action,
            subjective_state=subjective_state,
            active_option_id=active_option_id,
            planning_update=planning_update,
            created_subtasks=created_subtasks,
            curation_decision=curation_decision,
        )

    # ── training loop ─────────────────────────────────────────────────

    def train(
        self,
        world: World[ObservationT, ActionT, InfoT],
        *,  # keyword-only arguments from here on, for clarity
        num_episodes: int = 500,
        average_window: int = 100,
        solved_threshold: float | None = None,
        episode_logger: Callable[[int, float, float, Self], None] | None = None,
        episode_trace_logger: Callable[
            [EpisodeTrace[ObservationT, ActionT, SubjectiveStateT, InfoT]], None
        ]
        | None = None,
        trace_selector: Callable[[int, int], bool] | None = None,
        capture_rendered_frames: bool = False,
    ) -> list[float]:
        """Run the standard OaK episode loop on the given world.

        Parameters
        ----------
        world:
            An environment implementing the `World` protocol.
        num_episodes:
            Maximum number of training episodes.
        average_window:
            Number of recent episodes to average for performance tracking.
        solved_threshold:
            If set, stop early when the average reward over the last `average_window`
            episodes reaches this value.
        episode_logger:
            Optional callback `(episode, episode_reward, avg_reward, agent)`
            called after each episode. Use this to own all per-episode logging
            or other training-side effects at the call site.
        episode_trace_logger:
            Optional richer callback receiving an `EpisodeTrace` with the
            training world, agent, selected step records, and optionally
            rendered frames.
        trace_selector:
            Optional predicate `(episode, num_episodes) -> bool` used to decide
            which episodes should produce an `EpisodeTrace`. If omitted and an
            `episode_trace_logger` is provided, all episodes are traced.
        capture_rendered_frames:
            When `True`, collect frames from worlds that implement the optional
            `render_frame()` capability for traced episodes.

        Returns
        -------
        list[float]
            Per-episode reward history.
        """
        if average_window < 1:
            raise ValueError("average_window must be at least 1.")

        reward_history: list[float] = []

        for episode in range(num_episodes):
            time_step = world.reset()
            initial_time_step = time_step
            self.reset()
            episode_reward = 0.0
            step_count = 0
            capture_trace = episode_trace_logger is not None and (
                trace_selector(episode, num_episodes) if trace_selector is not None else True
            )
            traced_steps: list[
                EpisodeStepRecord[ObservationT, ActionT, SubjectiveStateT, InfoT]
            ] = []
            traced_frames: list[object] = []

            if capture_trace and capture_rendered_frames:
                frame = self._render_frame(world)
                if frame is not None:
                    traced_frames.append(frame)

            while True:
                result = self.step(time_step)

                if time_step.terminated or time_step.truncated:
                    break

                next_time_step = world.step(result.action)
                episode_reward += next_time_step.reward
                if capture_trace:
                    traced_steps.append(
                        EpisodeStepRecord(
                            step_index=step_count,
                            time_step=time_step,
                            action=result.action,
                            next_time_step=next_time_step,
                            active_option_id=result.active_option_id,
                            planning_update=result.planning_update,
                            created_subtasks=result.created_subtasks,
                        )
                    )
                    if capture_rendered_frames:
                        frame = self._render_frame(world)
                        if frame is not None:
                            traced_frames.append(frame)
                time_step = next_time_step
                step_count += 1

            reward_history.append(episode_reward)
            recent_window = reward_history[-average_window:]
            avg_reward = sum(recent_window) / len(recent_window)

            solved = (
                solved_threshold is not None
                and len(reward_history) >= average_window
                and avg_reward >= solved_threshold
            )

            if episode_logger is not None:
                episode_logger(episode, episode_reward, avg_reward, self)

            if episode_trace_logger is not None and capture_trace:
                episode_trace_logger(
                    EpisodeTrace(
                        episode=episode,
                        episode_reward=episode_reward,
                        avg_reward=avg_reward,
                        step_count=step_count,
                        solved=solved,
                        initial_time_step=initial_time_step,
                        final_time_step=time_step,
                        steps=tuple(traced_steps),
                        frames=tuple(traced_frames),
                        world=world,
                        agent=self,
                        metadata={
                            "num_episodes": num_episodes,
                            "capture_rendered_frames": capture_rendered_frames,
                        },
                    )
                )

            for component in (
                self.perception,
                self.value_function,
                self.transition_model,
                self.reactive_policy,
            ):
                end_episode = getattr(component, "end_episode", None)
                if callable(end_episode):
                    end_episode()

            if solved:
                break

        return reward_history

    # ── private helpers ──────────────────────────────────────────────

    def _build_usage_records(
        self,
        ranked_feature_ids: Sequence[FeatureId],
        active_option_id: OptionId | None,
    ) -> Sequence[UsageRecord]:
        """Build minimal utility-accounting observations for the current step."""
        usage_records = [
            UsageRecord(ComponentKind.FEATURE, feature_id)
            for feature_id in ranked_feature_ids
        ]
        if active_option_id is not None:
            usage_records.append(UsageRecord(ComponentKind.OPTION, active_option_id))
        return tuple(usage_records)

    def _apply_curation(self, decision: CurationDecision) -> None:
        """Dispatch curation decisions to the relevant modules."""
        if decision.drop_features:
            self.perception.remove_features(decision.drop_features)
        if decision.drop_subtasks:
            self.reactive_policy.remove_subtasks(decision.drop_subtasks)
        if decision.drop_options:
            self.reactive_policy.remove_options(decision.drop_options)
        if decision.drop_option_models:
            self.transition_model.remove_option_models(decision.drop_option_models)
        if decision.drop_general_value_functions:
            self.value_function.remove(decision.drop_general_value_functions)

    def _render_frame(
        self,
        world: World[ObservationT, ActionT, InfoT],
    ) -> object | None:
        """Capture one world frame when optional rendering is available."""
        if not isinstance(world, RenderableWorld):
            return None
        return world.render_frame()

Coordinates one full OaK step across the four modules.

The agent is a wiring object: you provide concrete implementations of Perception, TransitionModel, ValueFunction, and ReactivePolicy, and OaKAgent ensures they are called in a consistent order.
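Part of that wiring is curation: `OaKAgent._apply_curation` (in the class source) forwards each non-empty `drop_*` list of the `CurationDecision` to the module that owns those components. The self-contained sketch below reproduces only that routing table so it can be inspected without building an agent; the `Decision` dataclass and the `route` helper are stand-ins, not `oak` types.

```python
from dataclasses import dataclass


@dataclass
class Decision:
    """Stand-in with the same drop_* fields that _apply_curation reads."""

    drop_features: tuple = ()
    drop_subtasks: tuple = ()
    drop_options: tuple = ()
    drop_option_models: tuple = ()
    drop_general_value_functions: tuple = ()


# drop_* field -> (module attribute on OaKAgent, method called on it).
ROUTES = {
    "drop_features": ("perception", "remove_features"),
    "drop_subtasks": ("reactive_policy", "remove_subtasks"),
    "drop_options": ("reactive_policy", "remove_options"),
    "drop_option_models": ("transition_model", "remove_option_models"),
    "drop_general_value_functions": ("value_function", "remove"),
}


def route(decision):
    """List the (module, method, ids) calls the agent would make, in order."""
    calls = []
    for field_name, (module, method) in ROUTES.items():
        ids = getattr(decision, field_name)
        if ids:  # empty drop lists are skipped, as in _apply_curation
            calls.append((module, method, tuple(ids)))
    return calls


calls = route(Decision(drop_features=("f3",), drop_options=("o1", "o2")))
```

The value function decides *what* to drop, but each module removes only its own components, which is why `ReactivePolicy` exposes both `remove_subtasks` and `remove_options`.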

OaKAgent(perception: 'Perception[ObservationT, ActionT, SubjectiveStateT]', transition_model: 'TransitionModel[SubjectiveStateT, ActionT, InfoT]', value_function: 'ValueFunction[SubjectiveStateT, ActionT, InfoT]', reactive_policy: 'ReactivePolicy[SubjectiveStateT, ActionT, InfoT]', planning_budget: 'int' = 4, feature_budget: 'int' = 4, option_stop_threshold: 'float' = 0.5)
perception: 'Perception[ObservationT, ActionT, SubjectiveStateT]'
transition_model: 'TransitionModel[SubjectiveStateT, ActionT, InfoT]'
value_function: 'ValueFunction[SubjectiveStateT, ActionT, InfoT]'
reactive_policy: 'ReactivePolicy[SubjectiveStateT, ActionT, InfoT]'
planning_budget: 'int' = 4
feature_budget: 'int' = 4
option_stop_threshold: 'float' = 0.5
last_action: 'ActionT | None' = None
last_subjective_state: 'SubjectiveStateT | None' = None
last_active_option_id: 'OptionId | None' = None
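`OaKAgent` is decorated with `@dataclass` but defines its own `__init__`. In that situation `dataclasses` keeps the hand-written method instead of generating one, and only a generated `__init__` calls `__post_init__` automatically, so the budget and threshold checks run only if `__init__` invokes `__post_init__` itself. The standalone demo below (class names are illustrative) shows the difference.

```python
from dataclasses import dataclass


@dataclass
class SkipsValidation:
    budget: int = 4

    def __init__(self, budget: int = 4):
        self.budget = budget  # hand-written __init__: @dataclass does not replace it

    def __post_init__(self):
        if self.budget < 1:
            raise ValueError("budget must be at least 1.")


@dataclass
class RunsValidation:
    budget: int = 4

    def __init__(self, budget: int = 4):
        self.budget = budget
        self.__post_init__()  # explicit call restores the validation

    def __post_init__(self):
        if self.budget < 1:
            raise ValueError("budget must be at least 1.")


unchecked = SkipsValidation(budget=0)  # silently accepted
message = None
try:
    RunsValidation(budget=0)
except ValueError as err:
    message = str(err)
```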
def reset(self) -> 'None':

Clear transient execution memory.

def step(self, time_step: 'TimeStep[ObservationT, InfoT]') -> 'AgentStepResult[ActionT, SubjectiveStateT]':

Run one temporally uniform agent step.

The step follows six phases: perceive, learn, grow, plan, act, maintain.
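The phase order, and the fact that Learn is skipped until a previous state and action exist, can be checked with stand-in modules that simply log their calls. The sketch below is a simplified skeleton of `step`, not a drop-in replacement: `Recorder` and `mini_step` are hypothetical, the canned return values are arbitrary, and the meta step-size calls, usage-record construction, and curation dispatch are omitted.

```python
class Recorder:
    """Stand-in module: logs each method call and returns a canned value."""

    def __init__(self, name, log, returns):
        self.name, self.log, self.returns = name, log, returns

    def __getattr__(self, method):
        # Only reached for attributes not set in __init__, i.e. module methods.
        key = f"{self.name}.{method}"

        def call(*args, **kwargs):
            self.log.append(key)
            return self.returns.get(key)

        return call


def mini_step(mods, memory):
    """Phase order of OaKAgent.step, heavily simplified."""
    p, v = mods["perception"], mods["value_function"]
    t, pi = mods["transition_model"], mods["reactive_policy"]
    state = p.update()  # 1. Perceive
    if memory["state"] is not None and memory["action"] is not None:
        td_errors = v.update()  # 2. Learn (only once a previous step exists)
        pi.update(td_errors)
        t.update()
    ranked = p.discover_and_rank_features(state, v.utility_scores())  # 3. Grow
    if ranked:
        subtasks = p.generate_subtasks(ranked)
        if subtasks:
            pi.ingest_subtasks(subtasks)
    pi.integrate_options()
    t.integrate_option_models()
    pi.apply_planning_update(t.plan(state, v))  # 4. Plan
    action, _option_id = pi.select_action(state)  # 5. Act
    if ranked:
        v.observe_usage(ranked)  # 6. Maintain
    v.curate()
    memory.update(state=state, action=action)
    return action


log = []
returns = {
    "perception.update": "s0",
    "value_function.update": {"main": 0.0},
    "perception.discover_and_rank_features": ("f1",),
    "perception.generate_subtasks": (),
    "reactive_policy.select_action": ("a0", None),
}
names = ("perception", "value_function", "transition_model", "reactive_policy")
mods = {name: Recorder(name, log, returns) for name in names}
memory = {"state": None, "action": None}

mini_step(mods, memory)
first_step = list(log)
log.clear()
mini_step(mods, memory)
second_step = list(log)
```

On the first call the log goes straight from `perception.update` to the Grow phase; on the second, the three Learn calls appear in between, because the transition needs the previous state and action.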

def train(self, world: 'World[ObservationT, ActionT, InfoT]', *, num_episodes: 'int' = 500, average_window: 'int' = 100, solved_threshold: 'float | None' = None, episode_logger: 'Callable[[int, float, float, Self], None] | None' = None, episode_trace_logger: 'Callable[[EpisodeTrace[ObservationT, ActionT, SubjectiveStateT, InfoT]], None] | None' = None, trace_selector: 'Callable[[int, int], bool] | None' = None, capture_rendered_frames: 'bool' = False) -> 'list[float]':

Run the standard OaK episode loop on the given world.

Parameters

world: An environment implementing the World protocol.
num_episodes: Maximum number of training episodes.
average_window: Number of recent episodes to average for performance tracking.
solved_threshold: If set, stop early once at least average_window episodes have completed and the average reward over the last average_window episodes reaches this value.
episode_logger: Optional callback (episode, episode_reward, avg_reward, agent) called after each episode. Use it to handle per-episode logging or other training-side effects at the call site.
episode_trace_logger: Optional callback that receives an EpisodeTrace for each traced episode, containing the training world, the agent, the per-step records, and (optionally) rendered frames.
trace_selector: Optional predicate (episode, num_episodes) -> bool that decides which episodes produce an EpisodeTrace. If omitted while an episode_trace_logger is provided, every episode is traced.
capture_rendered_frames: When True, collect frames for traced episodes from worlds that implement the optional render_frame() capability.

Returns

list[float]: Per-episode reward history. It may be shorter than num_episodes if training stops early.
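The early-stopping rule in the loop above can be isolated as a small pure function. This sketch mirrors the inline `solved` expression from the source; the helper name `is_solved` is hypothetical and not part of the package.

```python
from __future__ import annotations

from collections.abc import Sequence


def is_solved(
    reward_history: Sequence[float],
    average_window: int,
    solved_threshold: float | None,
) -> tuple[float, bool]:
    """Hypothetical helper returning (avg_reward, solved) as train() computes them."""
    # Average over the most recent `average_window` episodes (fewer, early on).
    recent_window = reward_history[-average_window:]
    avg_reward = sum(recent_window) / len(recent_window)
    # Early stopping needs a threshold AND a full window of episodes.
    solved = (
        solved_threshold is not None
        and len(reward_history) >= average_window
        and avg_reward >= solved_threshold
    )
    return avg_reward, solved


history: list[float] = []
for reward in [10.0, 200.0, 200.0, 200.0]:
    history.append(reward)
    avg, solved = is_solved(history, average_window=3, solved_threshold=195.0)
    print(len(history), round(avg, 2), solved)
    if solved:
        break  # stops after the 4th episode: last three average 200.0
```

Because the window must be full, a single high-reward episode early in training cannot trigger a stop on its own.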

@runtime_checkable
class World(typing.Protocol[~WorldObservationT, -WorldActionT, ~WorldInfoT]):
@runtime_checkable
class World(Protocol[WorldObservationT, WorldActionT, WorldInfoT]):
    """Minimal environment protocol.

    A `World` may wrap a simulator, a benchmark environment, or a custom
    continual data source.  The protocol is intentionally small so the
    package does not depend on a specific environment library.

    Implement this protocol for any environment you want to use with
    `OaKAgent.train()`.
    """

    def reset(self) -> TimeStep[WorldObservationT, WorldInfoT]: ...

    def step(
        self, action: WorldActionT
    ) -> TimeStep[WorldObservationT, WorldInfoT]: ...

    def close(self) -> None:
        """Release environment resources.  Default is a no-op."""
        ...

Minimal environment protocol.

A World may wrap a simulator, a benchmark environment, or a custom continual data source. The protocol is intentionally small so the package does not depend on a specific environment library.

Implement this protocol for any environment you want to use with OaKAgent.train().

def reset(self) -> 'TimeStep[WorldObservationT, WorldInfoT]':
def step( self, action: 'WorldActionT') -> 'TimeStep[WorldObservationT, WorldInfoT]':
def close(self) -> 'None':

Release environment resources. Default is a no-op.

@dataclass(slots=True, frozen=True)
class AgentStepResult(typing.Generic[~ActionT, ~SubjectiveStateT]):
@dataclass(slots=True, frozen=True)
class AgentStepResult(Generic[ActionT, SubjectiveStateT]):
    """Observable result of one OaK agent step.

    This is the compact object a caller receives after stepping the agent. It
    includes the primitive action actually executed, the current subjective
    state, and any structures or planning signals created during that step.
    """

    action: ActionT
    subjective_state: SubjectiveStateT
    active_option_id: OptionId | None = None
    planning_update: PlanningUpdate[ActionT] | None = None
    created_subtasks: Sequence[SubtaskSpec] = field(default_factory=tuple)
    curation_decision: CurationDecision | None = None

Observable result of one OaK agent step.

This is the compact object a caller receives after stepping the agent. It includes the primitive action actually executed, the current subjective state, and any structures or planning signals created during that step.

AgentStepResult( action: 'ActionT', subjective_state: 'SubjectiveStateT', active_option_id: 'OptionId | None' = None, planning_update: 'PlanningUpdate[ActionT] | None' = None, created_subtasks: 'Sequence[SubtaskSpec]' = <factory>, curation_decision: 'CurationDecision | None' = None)
action: 'ActionT'
subjective_state: 'SubjectiveStateT'
active_option_id: 'OptionId | None'
planning_update: 'PlanningUpdate[ActionT] | None'
created_subtasks: 'Sequence[SubtaskSpec]'
curation_decision: 'CurationDecision | None'
class ComponentKind(builtins.str, enum.Enum):
class ComponentKind(str, Enum):
    """Kinds of learnable or managed elements in the architecture."""

    FEATURE = "feature"
    SUBTASK = "subtask"
    OPTION = "option"
    VALUE_FUNCTION = "value_function"
    OPTION_MODEL = "option_model"
    TRANSITION_MODEL = "transition_model"
    POLICY = "policy"
    PERCEPTION = "perception"
    PLANNER = "planner"

Kinds of learnable or managed elements in the architecture.

FEATURE = <ComponentKind.FEATURE: 'feature'>
SUBTASK = <ComponentKind.SUBTASK: 'subtask'>
OPTION = <ComponentKind.OPTION: 'option'>
VALUE_FUNCTION = <ComponentKind.VALUE_FUNCTION: 'value_function'>
OPTION_MODEL = <ComponentKind.OPTION_MODEL: 'option_model'>
TRANSITION_MODEL = <ComponentKind.TRANSITION_MODEL: 'transition_model'>
POLICY = <ComponentKind.POLICY: 'policy'>
PERCEPTION = <ComponentKind.PERCEPTION: 'perception'>
PLANNER = <ComponentKind.PLANNER: 'planner'>
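Because `ComponentKind` mixes in `str`, its members compare equal to their string values, can be looked up from plain strings (for example from a config file), and serialize to JSON as bare strings. The sketch below uses a local copy of the enum so it runs without the package.

```python
import json
from enum import Enum


class ComponentKind(str, Enum):
    # Local mirror of oak.ComponentKind so this sketch runs without the package.
    FEATURE = "feature"
    SUBTASK = "subtask"
    OPTION = "option"
    VALUE_FUNCTION = "value_function"
    OPTION_MODEL = "option_model"
    TRANSITION_MODEL = "transition_model"
    POLICY = "policy"
    PERCEPTION = "perception"
    PLANNER = "planner"


# Members compare equal to their string values...
assert ComponentKind.OPTION == "option"
# ...can be looked up by value (unknown strings raise ValueError)...
kind = ComponentKind("option_model")
# ...and serialize as plain JSON strings.
payload = json.dumps({"kind": ComponentKind.SUBTASK})
print(kind is ComponentKind.OPTION_MODEL, payload)  # True {"kind": "subtask"}
```

Use `.value` rather than `str(member)` when you need the raw string: the `str()` of a `(str, Enum)` member is the qualified name such as `ComponentKind.OPTION`.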
@dataclass(slots=True, frozen=True)
class CurationDecision:
@dataclass(slots=True, frozen=True)
class CurationDecision:
    """Pruning decision returned by the curator."""

    drop_features: Sequence[FeatureId] = field(default_factory=tuple)
    drop_subtasks: Sequence[SubtaskId] = field(default_factory=tuple)
    drop_options: Sequence[OptionId] = field(default_factory=tuple)
    drop_option_models: Sequence[OptionId] = field(default_factory=tuple)
    drop_general_value_functions: Sequence[GeneralValueFunctionId] = field(
        default_factory=tuple
    )
    notes: StructuredPayload = field(default_factory=dict)

Pruning decision returned by the curator.

CurationDecision( drop_features: 'Sequence[FeatureId]' = <factory>, drop_subtasks: 'Sequence[SubtaskId]' = <factory>, drop_options: 'Sequence[OptionId]' = <factory>, drop_option_models: 'Sequence[OptionId]' = <factory>, drop_general_value_functions: 'Sequence[GeneralValueFunctionId]' = <factory>, notes: 'StructuredPayload' = <factory>)
drop_features: 'Sequence[FeatureId]'
drop_subtasks: 'Sequence[SubtaskId]'
drop_options: 'Sequence[OptionId]'
drop_option_models: 'Sequence[OptionId]'
drop_general_value_functions: 'Sequence[GeneralValueFunctionId]'
notes: 'StructuredPayload'
@dataclass(slots=True, frozen=True)
class FeatureCandidate:
@dataclass(slots=True, frozen=True)
class FeatureCandidate:
    """A proposed feature that may be admitted into the feature bank."""

    feature_id: FeatureId
    name: str
    origin: str
    description: str = ""
    metadata: OpenPayload = field(default_factory=dict)

A proposed feature that may be admitted into the feature bank.

FeatureCandidate( feature_id: 'FeatureId', name: 'str', origin: 'str', description: 'str' = '', metadata: 'OpenPayload' = <factory>)
feature_id: 'FeatureId'
name: 'str'
origin: 'str'
description: 'str'
metadata: 'OpenPayload'
@dataclass(slots=True, frozen=True)
class FeatureSpec:
@dataclass(slots=True, frozen=True)
class FeatureSpec:
    """Metadata describing a feature tracked by the agent."""

    feature_id: FeatureId
    name: str
    description: str = ""
    metadata: OpenPayload = field(default_factory=dict)

Metadata describing a feature tracked by the agent.

FeatureSpec( feature_id: 'FeatureId', name: 'str', description: 'str' = '', metadata: 'OpenPayload' = <factory>)
feature_id: 'FeatureId'
name: 'str'
description: 'str'
metadata: 'OpenPayload'
@dataclass(slots=True, frozen=True)
class GeneralValueFunctionSpec(typing.Generic[~ActionT, ~SubjectiveStateT, ~InfoT]):
@dataclass(slots=True, frozen=True)
class GeneralValueFunctionSpec(Generic[ActionT, SubjectiveStateT, InfoT]):
    """General value function specification."""

    general_value_function_id: GeneralValueFunctionId
    name: str
    cumulant: ScalarSignal
    continuation: ContinuationFn
    termination_value: TerminationValueFn
    metadata: OpenPayload = field(default_factory=dict)

General value function specification.

GeneralValueFunctionSpec( general_value_function_id: 'GeneralValueFunctionId', name: 'str', cumulant: 'ScalarSignal', continuation: 'ContinuationFn', termination_value: 'TerminationValueFn', metadata: 'OpenPayload' = <factory>)
general_value_function_id: 'GeneralValueFunctionId'
name: 'str'
cumulant: 'ScalarSignal'
continuation: 'ContinuationFn'
termination_value: 'TerminationValueFn'
metadata: 'OpenPayload'
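The three question functions in `GeneralValueFunctionSpec` (cumulant, continuation, termination value) match the usual GVF formulation from the Horde line of work. The sketch computes the return such a spec defines along a finished trajectory, using the backward recursion G_t = c(S_{t+1}) + (1 - gamma(S_{t+1})) z(S_{t+1}) + gamma(S_{t+1}) G_{t+1}. The exact signatures of `ScalarSignal`, `ContinuationFn`, and `TerminationValueFn` are not shown here, so treating each as a plain function of the next state is an assumption, and `gvf_return` is a hypothetical name.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import TypeVar

S = TypeVar("S")


def gvf_return(
    states: Sequence[S],
    cumulant: Callable[[S], float],
    continuation: Callable[[S], float],
    termination_value: Callable[[S], float],
) -> float:
    """Return of a GVF from states[0] along a finished trajectory.

    Assumes each function is applied to the *next* state, and that the return
    beyond the last state is 0:
        G_t = c(S') + (1 - gamma(S')) * z(S') + gamma(S') * G_{t+1}
    """
    ret = 0.0
    for s in reversed(states[1:]):  # walk backwards over the successor states
        gamma = continuation(s)
        ret = cumulant(s) + (1.0 - gamma) * termination_value(s) + gamma * ret
    return ret


path = [0, 1, 2, 3]  # positions; position 3 is the wall / goal cell


def at_end(s: int) -> bool:
    return s == 3


# "How many steps until the wall?": cumulant 1 per step, stop at the wall.
steps_to_wall = gvf_return(
    path,
    cumulant=lambda s: 1.0,
    continuation=lambda s: 0.0 if at_end(s) else 1.0,
    termination_value=lambda s: 0.0,
)
# Discounted indicator of reaching the goal: paid only through z at termination.
reach_goal = gvf_return(
    path,
    cumulant=lambda s: 0.0,
    continuation=lambda s: 0.0 if at_end(s) else 0.9,
    termination_value=lambda s: 1.0 if at_end(s) else 0.0,
)
print(steps_to_wall, round(reach_goal, 3))  # 3.0 0.81
```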
@dataclass(slots=True, frozen=True)
class ModelPrediction(typing.Generic[~SubjectiveStateT]):
@dataclass(slots=True, frozen=True)
class ModelPrediction(Generic[SubjectiveStateT]):
    """Prediction returned by an action or option model."""

    predicted_subjective_state: SubjectiveStateT
    cumulative_reward: float
    steps: int | None = None
    terminated: bool = False
    metadata: OpenPayload = field(default_factory=dict)

Prediction returned by an action or option model.

ModelPrediction( predicted_subjective_state: 'SubjectiveStateT', cumulative_reward: 'float', steps: 'int | None' = None, terminated: 'bool' = False, metadata: 'OpenPayload' = <factory>)
predicted_subjective_state: 'SubjectiveStateT'
cumulative_reward: 'float'
steps: 'int | None'
terminated: 'bool'
metadata: 'OpenPayload'
@dataclass(slots=True, frozen=True)
class OptionDescriptor:
@dataclass(slots=True, frozen=True)
class OptionDescriptor:
    """Lightweight metadata for an option."""

    option_id: OptionId
    name: str
    subtask_id: SubtaskId | None = None
    metadata: OpenPayload = field(default_factory=dict)

Lightweight metadata for an option.

OptionDescriptor( option_id: 'OptionId', name: 'str', subtask_id: 'SubtaskId | None' = None, metadata: 'OpenPayload' = <factory>)
option_id: 'OptionId'
name: 'str'
subtask_id: 'SubtaskId | None'
metadata: 'OpenPayload'
@dataclass(slots=True, frozen=True)
class PlanningUpdate(typing.Generic[~ActionT]):
@dataclass(slots=True, frozen=True)
class PlanningUpdate(Generic[ActionT]):
    """Outputs from one planning pass."""

    value_targets: Mapping[GeneralValueFunctionId, float] = field(default_factory=dict)
    policy_targets: StructuredPayload = field(default_factory=dict)
    search_statistics: StructuredPayload = field(default_factory=dict)

Outputs from one planning pass.

PlanningUpdate( value_targets: 'Mapping[GeneralValueFunctionId, float]' = <factory>, policy_targets: 'StructuredPayload' = <factory>, search_statistics: 'StructuredPayload' = <factory>)
value_targets: 'Mapping[GeneralValueFunctionId, float]'
policy_targets: 'StructuredPayload'
search_statistics: 'StructuredPayload'
@dataclass(slots=True, frozen=True)
class PolicyDecision(typing.Generic[~ActionT]):
@dataclass(slots=True, frozen=True)
class PolicyDecision(Generic[ActionT]):
    """Return type for reactive policy selection."""

    action: ActionT | None = None
    option_id: OptionId | None = None
    metadata: OpenPayload = field(default_factory=dict)

    def __post_init__(self) -> None:
        has_action = self.action is not None
        has_option = self.option_id is not None
        if has_action == has_option:
            raise ValueError(
                "PolicyDecision requires exactly one of action or option_id."
            )

Return type for reactive policy selection.

PolicyDecision( action: 'ActionT | None' = None, option_id: 'OptionId | None' = None, metadata: 'OpenPayload' = <factory>)
action: 'ActionT | None'
option_id: 'OptionId | None'
metadata: 'OpenPayload'
@dataclass(slots=True, frozen=True)
class SubtaskSpec:
@dataclass(slots=True, frozen=True)
class SubtaskSpec:
    """A feature-grounded subtask description."""

    subtask_id: SubtaskId
    name: str
    feature_id: FeatureId
    intensity: float = 1.0
    general_value_function_id: GeneralValueFunctionId | None = None
    metadata: OpenPayload = field(default_factory=dict)

A feature-grounded subtask description.

SubtaskSpec( subtask_id: 'SubtaskId', name: 'str', feature_id: 'FeatureId', intensity: 'float' = 1.0, general_value_function_id: 'GeneralValueFunctionId | None' = None, metadata: 'OpenPayload' = <factory>)
subtask_id: 'SubtaskId'
name: 'str'
feature_id: 'FeatureId'
intensity: 'float'
general_value_function_id: 'GeneralValueFunctionId | None'
metadata: 'OpenPayload'
@dataclass(slots=True, frozen=True)
class TimeStep(typing.Generic[~ObservationT, ~InfoT]):
@dataclass(slots=True, frozen=True)
class TimeStep(Generic[ObservationT, InfoT]):
    """One environment emission seen by the agent.

    `TimeStep` is the object passed into `OaKAgent.step(...)`. It contains the
    raw observation, scalar reward, episode-control flags, and optional
    environment metadata.
    """

    observation: ObservationT
    reward: float
    terminated: bool = False
    truncated: bool = False
    info: InfoT | None = None

One environment emission seen by the agent.

TimeStep is the object passed into OaKAgent.step(...). It contains the raw observation, scalar reward, episode-control flags, and optional environment metadata.

TimeStep( observation: 'ObservationT', reward: 'float', terminated: 'bool' = False, truncated: 'bool' = False, info: 'InfoT | None' = None)
observation: 'ObservationT'
reward: 'float'
terminated: 'bool'
truncated: 'bool'
info: 'InfoT | None'
@dataclass(slots=True, frozen=True)
class Transition(typing.Generic[~ActionT, ~SubjectiveStateT, ~InfoT]):
@dataclass(slots=True, frozen=True)
class Transition(Generic[ActionT, SubjectiveStateT, InfoT]):
    """One subjective-state transition in agent terms.

    `Transition` is constructed by the agent after two consecutive time steps.
    Learners use it instead of the raw world stream so they can access both the
    previous and next subjective state representations together with reward,
    termination, and optional environment metadata.
    """

    subjective_state: SubjectiveStateT
    action: ActionT
    reward: float
    next_subjective_state: SubjectiveStateT
    terminated: bool = False
    option_id: OptionId | None = None
    info: InfoT | None = None

One subjective-state transition in agent terms.

Transition is constructed by the agent after two consecutive time steps. Learners use it instead of the raw world stream so they can access both the previous and next subjective state representations together with reward, termination, and optional environment metadata.

Transition( subjective_state: 'SubjectiveStateT', action: 'ActionT', reward: 'float', next_subjective_state: 'SubjectiveStateT', terminated: 'bool' = False, option_id: 'OptionId | None' = None, info: 'InfoT | None' = None)
subjective_state: 'SubjectiveStateT'
action: 'ActionT'
reward: 'float'
next_subjective_state: 'SubjectiveStateT'
terminated: 'bool'
option_id: 'OptionId | None'
info: 'InfoT | None'
@dataclass(slots=True, frozen=True)
class UsageRecord:
@dataclass(slots=True, frozen=True)
class UsageRecord:
    """Usage evidence gathered for utility assessment."""

    kind: ComponentKind
    component_id: ComponentId
    amount: float = 1.0
    metadata: OpenPayload = field(default_factory=dict)

Usage evidence gathered for utility assessment.

UsageRecord( kind: 'ComponentKind', component_id: 'ComponentId', amount: 'float' = 1.0, metadata: 'OpenPayload' = <factory>)
kind: 'ComponentKind'
component_id: 'ComponentId'
amount: 'float'
metadata: 'OpenPayload'
@dataclass(slots=True, frozen=True)
class UtilityRecord:
@dataclass(slots=True, frozen=True)
class UtilityRecord:
    """Utility score for one architectural element."""

    kind: ComponentKind
    component_id: ComponentId
    utility: float
    evidence: StructuredPayload = field(default_factory=dict)

Utility score for one architectural element.

UtilityRecord( kind: 'ComponentKind', component_id: 'ComponentId', utility: 'float', evidence: 'StructuredPayload' = <factory>)
kind: 'ComponentKind'
component_id: 'ComponentId'
utility: 'float'
evidence: 'StructuredPayload'