oak
from .agent import OaKAgent
from . import fine_grained
from .interfaces import (
    ContinualLearner,
    Perception,
    ReactivePolicy,
    TransitionModel,
    ValueFunction,
    World,
)
from .types import (
    AgentStepResult,
    ComponentKind,
    CurationDecision,
    FeatureCandidate,
    FeatureSpec,
    GeneralValueFunctionSpec,
    ModelPrediction,
    OptionDescriptor,
    PlanningUpdate,
    PolicyDecision,
    SubtaskSpec,
    TimeStep,
    Transition,
    UsageRecord,
    UtilityRecord,
)

__all__ = [
    # ── Continual-learning mixin ──
    "ContinualLearner",
    # ── The four main OaK interfaces ──
    "Perception",
    "TransitionModel",
    "ValueFunction",
    "ReactivePolicy",
    # ── Agent ──
    "OaKAgent",
    # ── Environment ──
    "World",
    # ── Optional advanced assembly layer ──
    "fine_grained",
    # ── Shared types ──
    "AgentStepResult",
    "ComponentKind",
    "CurationDecision",
    "FeatureCandidate",
    "FeatureSpec",
    "GeneralValueFunctionSpec",
    "ModelPrediction",
    "OptionDescriptor",
    "PlanningUpdate",
    "PolicyDecision",
    "SubtaskSpec",
    "TimeStep",
    "Transition",
    "UsageRecord",
    "UtilityRecord",
]
Architecture Guide
Diagram Gallery
[Diagram caption fragment: ...OaKAgent -> Composite* -> fine_grained interface used during that step.]
What You Must Implement
OaKAgent is the canonical coordinator. It is composed of exactly four
objects, one per Sutton module:
- perception: implements Perception. It receives raw environment data and must return the current subjective_state. It also manages feature discovery, ranking, and subtask generation.
- transition_model: implements TransitionModel. It learns from transitions, maintains option models, and runs bounded planning using the world model and value function.
- value_function: implements ValueFunction. It learns from Transition objects, predicts values, tracks the utility of learned structures, and produces curation decisions.
- reactive_policy: implements ReactivePolicy. It selects actions (primitive or options), manages the option library and option learning, and integrates planning updates.
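You also configure scalar controls:
- planning_budget (default 4): the budget passed to transition_model.plan(...).
- feature_budget (default 4): the maximum number of ranked feature IDs that perception.discover_and_rank_features(...) returns.
- option_stop_threshold (default 0.5): passed to reactive_policy.select_action(...).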
OaKAgent manages these runtime fields itself:
- last_action
- last_subjective_state
- last_active_option_id
Your environment must implement the World protocol (reset, step,
close) to use OaKAgent.train(). You can also drive the loop yourself
by supplying TimeStep objects to OaKAgent.step(...) directly.
Two Ways to Implement
Direct approach: implement the four main interfaces directly. Each of
your classes is a self-contained module. This is the simplest path and what
the examples/smoke/minimal_oak.py example demonstrates.
Composite approach: use the fine-grained component interfaces from
oak.fine_grained.components and wire them together using the
composites from oak.fine_grained.composites. This is for
projects that need to independently swap building blocks inside a module
(e.g. replace the planner without touching the world model). The
examples/smoke/minimal_oak_fine_grained.py example demonstrates this path with
the same toy behavior as the direct example.
| Main interface | Composite class | Fine-grained building blocks |
|---|---|---|
| Perception | CompositePerception | StateBuilder, FeatureBank, FeatureConstructor, FeatureRanker, SubtaskGenerator |
| TransitionModel | CompositeTransitionModel | WorldModel, OptionModelLearner, OptionModel, Planner |
| ValueFunction | CompositeValueFunction | ValueEstimator, GeneralValueFunctionLearner, UtilityAssessor, Curator, MetaStepSizeLearner |
| ReactivePolicy | CompositeReactivePolicy | ActionSelector, Option, OptionLibrary, OptionLearner, OptionKeyboard (optional) |
Diagram-to-Code Mapping
The diagrams have different jobs, but they all describe the same implementation:
- oak_core: the default conceptual slot map: OaKAgent plus the four main interfaces and the main data flow between them.
- oak_architecture: the fine-grained slot map: Composite modules, their delegated interfaces, and associated optional interfaces from oak.fine_grained.components.
- oak_runtime_overview: the top-level, phase-by-phase sequence at the four-interface layer.
- oak_runtime_sequence: the composite-wired per-step call order, showing only the fine-grained interfaces actually touched during one step(...).
Recommended reading order for the diagrams:
1. Read oak_core to understand the default four-interface surface.
2. Read oak_runtime_overview for the six phases of step(...).
3. Read oak_architecture to see how the optional fine-grained layer is assembled.
4. Read oak_runtime_sequence to trace one composite-wired execution path.
oak_runtime_overview and oak_runtime_sequence describe the same six
phases. The difference is only the level of expansion: oak_runtime_overview
stays at the four-interface layer, while oak_runtime_sequence shows what
happens when those slots are filled by the Composite* implementations from
oak.fine_grained.composites. If either diagram and the code
ever disagree, the documentation should be fixed.
The diagrams are intentionally runtime-oriented. They are not exhaustive
method inventories for the interfaces. For the full surface area
(reset, predict, current_subjective_state, OptionKeyboard, and so on),
use the API reference below. oak_architecture is the broadest inventory
view; oak_runtime_overview and oak_runtime_sequence are narrower and only
show what matters for one OaKAgent.step(...).
Step Walkthrough
Read the method as a pipeline. Each block below corresponds to the next
block of code in OaKAgent.step(...).
1. Perceive
subjective_state = self.perception.update(...)
time_step is the input. It carries observation, reward, terminated,
truncated, and optional info. perception.update(...) receives the
observation, the reward, and the agent's last_action, and must turn them
into the current subjective_state. Every later call in the step uses this
subjective_state, so your Perception implementation defines what the
agent actually reasons over.
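As an illustration, here is a minimal sketch of the update/current_subjective_state pair. It is duck-typed rather than subclassing oak.Perception so that it runs on its own; a real implementation would subclass Perception and also implement discover_and_rank_features, generate_subtasks, list_features, and remove_features. The class name WindowPerception and the choice of subjective state (the last k observations plus the last action and reward) are hypothetical.

```python
from collections import deque
from typing import Any, Deque, Optional, Tuple


class WindowPerception:
    """Hypothetical toy perception: last `k` observations + last action + reward."""

    def __init__(self, k: int = 3) -> None:
        self.k = k
        self._window: Deque[Any] = deque(maxlen=k)
        self._state: Optional[Tuple[Any, ...]] = None

    def reset(self) -> None:
        self._window.clear()
        self._state = None

    def update(self, observation: Any, reward: float, last_action: Optional[Any]) -> Tuple[Any, ...]:
        # The deque silently drops the oldest observation once it holds k items.
        self._window.append(observation)
        # Everything downstream (value, model, policy) only ever sees this tuple.
        self._state = (tuple(self._window), last_action, reward)
        return self._state

    def current_subjective_state(self) -> Tuple[Any, ...]:
        if self._state is None:
            raise RuntimeError("update() has not been called since reset()")
        return self._state


perception = WindowPerception(k=2)
perception.update(observation=0, reward=0.0, last_action=None)
perception.update(observation=1, reward=0.0, last_action="right")
state = perception.update(observation=2, reward=1.0, last_action="right")
print(state)  # ((1, 2), 'right', 1.0)
```

Keeping the subjective state immutable (a tuple) makes it safe to store as the "previous" state for the next transition.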
2. Learn
td_errors = self.value_function.update(transition)
self.reactive_policy.update(transition, td_errors)
self.transition_model.update(transition)
Learning starts only once the agent has both a previous subjective_state
and a previous action. The first call to step(...) therefore sets up
memory but cannot yet build a full transition.
The Transition packages the previous/next subjective states, the action,
reward, the termination outcome, and optional info. All three modules
receive it.
value_function.update returns TD errors that reactive_policy.update
uses for policy improvement.
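To make the hand-off concrete, the sketch below shows a tabular TD(0) value function whose update returns a mapping of TD errors keyed by a GVF id, mirroring the Mapping[GeneralValueFunctionId, float] return type of ValueFunction.update. The ToyTransition dataclass and its field names, the "main" key, and the step-size and discount values are assumptions; the real Transition type lives in oak.types and may name its fields differently.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Hashable, Mapping


@dataclass(frozen=True)
class ToyTransition:
    # Hypothetical stand-in for oak.types.Transition; field names are assumed.
    subjective_state: Hashable
    action: Any
    reward: float
    next_subjective_state: Hashable
    terminated: bool


class TabularTDValueFunction:
    """Hypothetical tabular TD(0) learner for a single prediction target."""

    def __init__(self, step_size: float = 0.1, discount: float = 0.9) -> None:
        self.step_size = step_size
        self.discount = discount
        self.values: Dict[Hashable, float] = defaultdict(float)

    def predict(self, subjective_state: Hashable) -> Mapping[str, float]:
        return {"main": self.values[subjective_state]}

    def update(self, transition: ToyTransition, *, planning: bool = False) -> Mapping[str, float]:
        # `planning` mirrors the interface signature; this sketch treats both cases alike.
        # Bootstrap from the next state unless the episode terminated.
        next_value = 0.0 if transition.terminated else self.values[transition.next_subjective_state]
        target = transition.reward + self.discount * next_value
        td_error = target - self.values[transition.subjective_state]
        self.values[transition.subjective_state] += self.step_size * td_error
        # Returned so that reactive_policy.update(transition, td_errors) can consume it.
        return {"main": td_error}


vf = TabularTDValueFunction(step_size=0.5, discount=0.9)
td_errors = vf.update(ToyTransition("s0", "right", 1.0, "s1", terminated=False))
print(td_errors, vf.predict("s0"))  # {'main': 1.0} {'main': 0.5}
```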
3. Grow
ranked_feature_ids = self.perception.discover_and_rank_features(...)
created_subtasks = self.perception.generate_subtasks(ranked_feature_ids)
self.reactive_policy.ingest_subtasks(created_subtasks)
self.reactive_policy.integrate_options()
self.transition_model.integrate_option_models()
perception proposes new features, ranks them by utility, and generates
subtasks from the most useful ones. reactive_policy turns subtasks into
options. transition_model integrates the latest option models so planning
can reason about them. In the overview diagram this appears as top-level
module calls; in the detailed sequence diagram the same phase is expanded into
Composite* -> fine_grained interface calls.
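A minimal sketch of how the Grow phase can hand structures from one module to the next, assuming each ranked feature becomes one subtask ("drive this feature high") and each ingested subtask is promoted to an option when integrate_options() runs. The ToyGrowPolicy class, the "subtask:"/"option:" ID scheme, and promoting subtasks immediately (with no actual option learning) are hypothetical simplifications.

```python
from typing import Dict, List, Sequence


def generate_subtasks(ranked_feature_ids: Sequence[str]) -> List[str]:
    # One subtask per ranked feature.
    return [f"subtask:{fid}" for fid in ranked_feature_ids]


class ToyGrowPolicy:
    """Hypothetical reactive-policy fragment: subtasks -> pending -> option library."""

    def __init__(self) -> None:
        self.pending_subtasks: List[str] = []
        self.option_library: Dict[str, str] = {}  # option id -> subtask id

    def ingest_subtasks(self, subtasks: Sequence[str]) -> None:
        # Ignore subtasks that are already pending or already backed by an option.
        known = set(self.pending_subtasks) | set(self.option_library.values())
        self.pending_subtasks.extend(s for s in subtasks if s not in known)

    def integrate_options(self) -> None:
        # A real OptionLearner would export only options it has learned well enough;
        # this sketch promotes every pending subtask at once.
        for subtask_id in self.pending_subtasks:
            self.option_library["option:" + subtask_id.split(":", 1)[1]] = subtask_id
        self.pending_subtasks.clear()


policy = ToyGrowPolicy()
policy.ingest_subtasks(generate_subtasks(["f3", "f1"]))
policy.integrate_options()
policy.ingest_subtasks(generate_subtasks(["f1"]))  # already has an option, ignored
print(sorted(policy.option_library))  # ['option:f1', 'option:f3']
```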
4. Plan
planning_update = self.transition_model.plan(
    subjective_state, self.value_function, self.planning_budget
)
self.reactive_policy.apply_planning_update(planning_update)
transition_model.plan(...) receives the current subjective_state, the
value_function (for state evaluation during search), and a budget. It
returns a PlanningUpdate. reactive_policy is informed about the
planner's output before action selection.
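One simple reading of "bounded planning" is a one-step lookahead that spends at most budget simulated model steps, scores each simulated successor with the value function, and returns per-action targets. The sketch assumes a deterministic tabular world model, a predict() that returns {"main": value}, and a hypothetical ToyPlanningUpdate with action_values and best_action fields; oak's actual PlanningUpdate may carry different fields (value targets, policy targets, or search statistics, per the API reference).

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Hashable, Mapping, Optional, Tuple

# (state, action) -> (reward, next_state): a hypothetical learned deterministic model.
WorldModel = Dict[Tuple[Hashable, Hashable], Tuple[float, Hashable]]


@dataclass
class ToyPlanningUpdate:
    # Hypothetical stand-in for oak.types.PlanningUpdate.
    action_values: Dict[Hashable, float] = field(default_factory=dict)
    best_action: Optional[Hashable] = None


def plan(
    subjective_state: Hashable,
    predict: Callable[[Hashable], Mapping[str, float]],
    model: WorldModel,
    budget: int,
    discount: float = 0.9,
) -> ToyPlanningUpdate:
    update = ToyPlanningUpdate()
    candidates = [a for (s, a) in model if s == subjective_state]
    # The budget caps how many simulated model steps are taken.
    for action in candidates[:budget]:
        reward, next_state = model[(subjective_state, action)]
        update.action_values[action] = reward + discount * predict(next_state)["main"]
    if update.action_values:
        update.best_action = max(update.action_values, key=update.action_values.get)
    return update


values = {"s1": 0.0, "s2": 10.0}
model: WorldModel = {("s0", "left"): (1.0, "s1"), ("s0", "right"): (0.0, "s2")}
result = plan("s0", lambda s: {"main": values[s]}, model, budget=4)
print(result.best_action, result.action_values)  # right {'left': 1.0, 'right': 9.0}
```

Passing the value function into plan(...), rather than letting the planner own one, is what lets the planner evaluate leaves with the same predictions the rest of the agent learns.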
5. Act
action, active_option_id = self.reactive_policy.select_action(
    subjective_state, self.option_stop_threshold
)
The reactive policy either continues an active option or makes a fresh
decision. The output is always a primitive action, because that is what
the caller receives in AgentStepResult.
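A sketch of the continue-or-decide logic, assuming each option exposes a policy (state to primitive action) and a termination probability, and that the active option keeps control while its termination probability stays below option_stop_threshold. That reading of the threshold, the ToyOption and ToySelector classes, and the injected greedy fallback are assumptions, not documented oak behavior.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Hashable, Optional, Tuple


@dataclass
class ToyOption:
    # Hypothetical option: an internal policy plus a termination function beta(s).
    policy: Callable[[Hashable], int]
    termination: Callable[[Hashable], float]


class ToySelector:
    def __init__(self, options: Dict[str, ToyOption], greedy: Callable[[Hashable], Tuple[int, Optional[str]]]):
        self.options = options
        self.greedy = greedy  # fresh decision: (primitive_action, option_id or None)
        self.active_option_id: Optional[str] = None

    def select_action(self, subjective_state: Hashable, option_stop_threshold: float) -> Tuple[int, Optional[str]]:
        if self.active_option_id is not None:
            option = self.options[self.active_option_id]
            if option.termination(subjective_state) < option_stop_threshold:
                # Continue the active option with its own primitive action.
                return option.policy(subjective_state), self.active_option_id
            self.active_option_id = None  # option stops; fall through to a fresh choice
        action, option_id = self.greedy(subjective_state)
        self.active_option_id = option_id
        if option_id is not None:
            # A newly chosen option emits its first primitive action immediately.
            action = self.options[option_id].policy(subjective_state)
        return action, option_id

    def clear_active_option(self) -> None:
        self.active_option_id = None


go_right = ToyOption(policy=lambda s: 1, termination=lambda s: 1.0 if s >= 3 else 0.0)
selector = ToySelector({"go_right": go_right}, greedy=lambda s: (0, "go_right") if s == 0 else (0, None))
trace = [selector.select_action(s, option_stop_threshold=0.5) for s in [0, 1, 2, 3, 4]]
print(trace)  # [(1, 'go_right'), (1, 'go_right'), (1, 'go_right'), (0, None), (0, None)]
```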
6. Maintain
self.value_function.observe_usage(usage_records)
curation_decision = self.value_function.curate()
self._apply_curation(curation_decision)
Usage records for ranked features and the active option are sent to the
value function for utility tracking. The value function then decides what
to prune. _apply_curation(...) dispatches the decision to the relevant
modules: perception.remove_features(...), reactive_policy.remove_options(...),
reactive_policy.remove_subtasks(...), transition_model.remove_option_models(...),
value_function.remove(...).
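The dispatch described above can be sketched as follows. The removal method names are the ones listed in this guide; the field names on ToyCurationDecision (feature_ids, option_ids, subtask_ids, general_value_function_ids) are assumptions, so check oak.types.CurationDecision for the real ones. Removing option models under the same IDs as their options is also an assumption, and Recorder is a hypothetical test double.

```python
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class ToyCurationDecision:
    # Hypothetical stand-in for oak.types.CurationDecision; field names are assumed.
    feature_ids: List[str] = field(default_factory=list)
    option_ids: List[str] = field(default_factory=list)
    subtask_ids: List[str] = field(default_factory=list)
    general_value_function_ids: List[str] = field(default_factory=list)


def apply_curation(decision, perception, reactive_policy, transition_model, value_function) -> None:
    # This sketch skips empty lists so modules are only called when something is pruned.
    if decision.feature_ids:
        perception.remove_features(decision.feature_ids)
    if decision.option_ids:
        reactive_policy.remove_options(decision.option_ids)
        # Assumes option models are keyed by the same IDs as the options.
        transition_model.remove_option_models(decision.option_ids)
    if decision.subtask_ids:
        reactive_policy.remove_subtasks(decision.subtask_ids)
    if decision.general_value_function_ids:
        value_function.remove(decision.general_value_function_ids)


class Recorder:
    """Hypothetical test double that logs every method call made on it."""

    def __init__(self, name: str, log: list) -> None:
        self.name, self.log = name, log

    def __getattr__(self, method: str):
        def record(ids: Sequence[str]) -> None:
            self.log.append((f"{self.name}.{method}", list(ids)))
        return record


calls: list = []
perception, policy, model, vf = (
    Recorder(n, calls) for n in ("perception", "reactive_policy", "transition_model", "value_function")
)
apply_curation(ToyCurationDecision(feature_ids=["f7"], option_ids=["o2"]), perception, policy, model, vf)
print(calls)
```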
Training Loop
OaKAgent.train() provides a standard episode loop so implementations
don't need to rewrite the reset/step/terminate boilerplate:
agent = build_my_agent()
world = MyWorld()  # must implement the World protocol

def log_episode(episode, reward, avg_reward, agent):
    if episode % 10 == 0:
        print(f"episode={episode} reward={reward:.1f} avg={avg_reward:.1f}")

rewards = agent.train(
    world,
    num_episodes=500,
    solved_threshold=475.0,  # optional early stopping
    episode_logger=log_episode,
)
world.close()
The World protocol requires three methods:
- reset() -> TimeStep: start a new episode.
- step(action) -> TimeStep: advance one step.
- close() -> None: release resources (can be a no-op).
If you need custom per-episode logging, pass episode_logger(...). If you
need a fully custom training loop (non-episodic environments, multi-agent
setups, custom control flow), call agent.step(time_step) directly instead.
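For reference, here is a toy environment with the three World methods, plus a hand-written loop of the kind you might use instead of train(). ToyTimeStep only mirrors the TimeStep fields this guide lists (observation, reward, terminated, truncated, info); the real oak.types.TimeStep constructor is not shown here, so treat it as an assumption. CorridorWorld and run_episode are hypothetical, and the act callable stands in for a wrapper around agent.step(time_step); how you read the primitive action out of AgentStepResult depends on that type's fields, which this sketch does not assume.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class ToyTimeStep:
    # Mirrors the TimeStep fields described in this guide; not the real oak class.
    observation: int
    reward: float
    terminated: bool
    truncated: bool
    info: Dict[str, Any] = field(default_factory=dict)


class CorridorWorld:
    """Hypothetical World: walk right from 0 to `length`; reward 1 at the goal."""

    def __init__(self, length: int = 4, max_steps: int = 20) -> None:
        self.length, self.max_steps = length, max_steps
        self.position, self.steps = 0, 0

    def reset(self) -> ToyTimeStep:
        self.position, self.steps = 0, 0
        return ToyTimeStep(observation=0, reward=0.0, terminated=False, truncated=False)

    def step(self, action: int) -> ToyTimeStep:
        # Action 1 moves right; anything else moves left (clamped at 0).
        self.position = max(0, self.position + (1 if action == 1 else -1))
        self.steps += 1
        at_goal = self.position >= self.length
        return ToyTimeStep(
            observation=self.position,
            reward=1.0 if at_goal else 0.0,
            terminated=at_goal,
            truncated=not at_goal and self.steps >= self.max_steps,
        )

    def close(self) -> None:
        pass  # nothing to release


def run_episode(world: CorridorWorld, act: Callable[[ToyTimeStep], int]) -> float:
    """Custom loop: feed each TimeStep to `act` until terminated or truncated."""
    time_step = world.reset()
    total = 0.0
    while True:
        time_step = world.step(act(time_step))
        total += time_step.reward
        if time_step.terminated or time_step.truncated:
            return total


world = CorridorWorld(length=4)
try:
    print(run_episode(world, act=lambda ts: 1))  # 1.0
finally:
    world.close()
```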
Implementation Order
If your goal is to get a working agent quickly, implement in this order:
1. Make Perception produce a useful subjective_state from TimeStep. Have discover_and_rank_features return a fixed list and generate_subtasks return an empty list.
2. Make ReactivePolicy return a valid (primitive_action, None) pair from select_action. Make the other methods no-ops.
3. Make ValueFunction accept update (an empty TD-error mapping is a valid return value) and return values from predict. Have utility_scores return an empty sequence and curate return an empty CurationDecision.
4. Make TransitionModel accept update and return a valid PlanningUpdate from plan, even if trivial.
That is enough to satisfy the exact call sequence of OaKAgent.step(...).
After that, you can improve learning quality without changing the basic
wiring.
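A quick way to check that stubs built in this order are enough is to drive them through the six phases in the order the Step Walkthrough documents and record the calls. The run_step driver below is a simplified, hypothetical restatement, not OaKAgent.step itself: it uses a tuple instead of a Transition, passes empty lists where the real agent passes utility scores and usage records, hard-codes the default budgets, and skips curation dispatch. Stub is also a hypothetical helper.

```python
from typing import Any, Dict, List, Optional, Tuple


class Stub:
    """Hypothetical test double: logs each call and returns a canned value per method."""

    def __init__(self, name: str, log: List[str], returns: Dict[str, Any]) -> None:
        self._name, self._log, self._returns = name, log, returns

    def __getattr__(self, method: str):
        def call(*args: Any, **kwargs: Any) -> Any:
            self._log.append(f"{self._name}.{method}")
            return self._returns.get(method)  # None stands in for no-ops / trivial objects
        return call


def run_step(perception, value_function, reactive_policy, transition_model,
             observation: Any, reward: float, previous: Optional[Tuple[Any, Any]]) -> Tuple[Any, Any]:
    """Simplified six-phase order; `previous` is (state, action) or None."""
    last_action = previous[1] if previous is not None else None
    state = perception.update(observation, reward, last_action)  # 1. Perceive
    if previous is not None:  # 2. Learn: needs a previous state and action
        transition = (previous[0], previous[1], reward, state)  # stand-in for Transition
        td_errors = value_function.update(transition)
        reactive_policy.update(transition, td_errors)
        transition_model.update(transition)
    ranked = perception.discover_and_rank_features(state, [], 4)  # 3. Grow
    subtasks = perception.generate_subtasks(ranked)
    reactive_policy.ingest_subtasks(subtasks)
    reactive_policy.integrate_options()
    transition_model.integrate_option_models()
    update = transition_model.plan(state, value_function, 4)  # 4. Plan
    reactive_policy.apply_planning_update(update)
    action, _ = reactive_policy.select_action(state, 0.5)  # 5. Act
    value_function.observe_usage([])  # 6. Maintain
    value_function.curate()
    return state, action


log: List[str] = []
perception = Stub("perception", log, {"update": "s", "discover_and_rank_features": ["f0"], "generate_subtasks": []})
value_function = Stub("value_function", log, {"update": {}})
reactive_policy = Stub("reactive_policy", log, {"select_action": (0, None)})
transition_model = Stub("transition_model", log, {})

previous = run_step(perception, value_function, reactive_policy, transition_model, 0, 0.0, None)
first_step_calls = len(log)  # no Learn calls on the very first step
run_step(perception, value_function, reactive_policy, transition_model, 1, 0.0, previous)
print(log[first_step_calls:first_step_calls + 4])
```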
Repository Examples
The concrete implementations live outside oak on purpose.
That shows the intended usage pattern: the package provides the canonical
OaKAgent coordinator and interfaces, while downstream code provides the
implementations. The generated docs now include the repository-level
examples package alongside the core oak API.
- examples/smoke/minimal_oak.py: a full smoke-path implementation using the direct approach. Each of the four interfaces is implemented as a single class with intentionally small behavior.
- examples/smoke/minimal_oak_fine_grained.py: the same toy environment built from the fine-grained composite building blocks instead of direct interface implementations.
- examples/example_01/: a fuller learning agent that exercises discovery, perception, planning, value learning, and reactive control together.
To run all repository smoke tests, including the minimal example:
pixi run tests
To inspect the smallest runnable example directly:
from examples.smoke.minimal_oak import build_minimal_agent, run_minimal_episode
agent = build_minimal_agent()
trace = run_minimal_episode(horizon=5)
run_minimal_episode(...) returns a compact trace with the
subjective_state, primitive action, active_option_id,
created_subtasks, and planner output at each step.
Design Constraints
Keep these constraints in mind when you replace the minimal pieces with real ones:
- Perception should define a useful subjective_state for the domain. The rest of the agent only sees that representation.
- ReactivePolicy should stay focused on choosing between primitive actions and options. It should not absorb the work of planning or prediction.
- ValueFunction should start with one meaningful predictive target before you expand to many General Value Functions.
- TransitionModel should make honest predictions. Bounded planning quickly becomes misleading if the model invents certainty it does not have.
- ValueFunction.curate() should stay conservative until you have stable evidence that a learned structure is safely removable.
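One way to keep curate() conservative is to require both a minimum amount of usage evidence and a low average utility before proposing removal, and to cap how much can be pruned in a single call. The UsageStats bookkeeping, the function names, and the thresholds below are illustrative assumptions, not oak defaults.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class UsageStats:
    # Hypothetical per-structure bookkeeping built up from observe_usage() calls.
    uses: int = 0
    utility: float = 0.0  # running average of observed utility


def observe(stats: Dict[str, UsageStats], structure_id: str, utility: float) -> None:
    s = stats.setdefault(structure_id, UsageStats())
    s.uses += 1
    s.utility += (utility - s.utility) / s.uses  # incremental mean


def conservative_drop_list(
    stats: Dict[str, UsageStats], min_uses: int = 50, max_utility: float = 0.01, max_fraction: float = 0.1
) -> List[str]:
    # Only structures with enough evidence *and* low utility are candidates.
    eligible = [(s.utility, sid) for sid, s in stats.items() if s.uses >= min_uses and s.utility < max_utility]
    # Never propose more than a small fraction of all tracked structures at once.
    limit = int(len(stats) * max_fraction)
    return [sid for _, sid in sorted(eligible)[:limit]]


stats: Dict[str, UsageStats] = {}
for i in range(20):
    for _ in range(10 if i == 1 else 60):
        observe(stats, f"f{i}", 0.0 if i in (0, 1) else 1.0)
print(conservative_drop_list(stats))  # ['f0']  (f1 is useless too, but has too little evidence)
```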
API Documentation
class ContinualLearner
Mixin for modules whose weights are adapted by meta-learned step sizes.
In Sutton's OaK architecture, every learned weight has a dedicated step-size parameter adapted via online cross-validation (e.g. IDBD, Sutton 1992; Adam-IDBD, Degris et al. 2024).
The agent loop calls update_meta() on all four modules after each
learning step, passing the same error-signals dict. Each module
internally decides which signals are relevant and routes them to its
per-weight step-size adaptation.
The default implementation is a no-op so that modules without meta-learning still work unchanged.
def update_meta(self, error_signals: Mapping[str, float]) -> None
Adapt internal per-weight step sizes given error signals.
Parameters
error_signals:
Named scalar error signals from the current learning step,
e.g. {"main_td_error": 0.05, "reward": 1.0}.
Implementations pick the signals they need and ignore the rest.
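As a concrete instance of per-weight step-size adaptation, the sketch below implements IDBD (Sutton 1992, cited in the class docstring above) for a linear predictor and exposes it through update_meta. It assumes the module caches the feature vector from its last prediction and that the relevant error arrives under the "main_td_error" key used in the docstring example. The class name, meta step size, and initial step size are illustrative choices, not part of oak, and the usage below feeds a supervised error in place of a TD error.

```python
import math
from typing import List, Mapping, Sequence


class IDBDLinearLearner:
    """Hypothetical linear predictor with IDBD per-weight step sizes (Sutton, 1992)."""

    def __init__(self, num_features: int, meta_step_size: float = 0.01, initial_step_size: float = 0.05) -> None:
        self.theta = meta_step_size
        self.weights: List[float] = [0.0] * num_features
        self.log_step_sizes: List[float] = [math.log(initial_step_size)] * num_features  # beta_i
        self.traces: List[float] = [0.0] * num_features  # h_i
        self._last_features: List[float] = [0.0] * num_features

    def predict(self, features: Sequence[float]) -> float:
        self._last_features = list(features)  # cached for the next update_meta call
        return sum(w * x for w, x in zip(self.weights, features))

    def update_meta(self, error_signals: Mapping[str, float]) -> None:
        delta = error_signals.get("main_td_error")
        if delta is None:
            return  # signal not relevant to this module
        for i, x in enumerate(self._last_features):
            # Meta step: grow beta_i when the current and past updates agree in sign.
            self.log_step_sizes[i] += self.theta * delta * x * self.traces[i]
            alpha = math.exp(self.log_step_sizes[i])
            self.weights[i] += alpha * delta * x
            # Trace of recent weight changes, decayed by [1 - alpha * x^2]^+.
            self.traces[i] = self.traces[i] * max(0.0, 1.0 - alpha * x * x) + alpha * delta * x


learner = IDBDLinearLearner(num_features=2)
for _ in range(500):
    prediction = learner.predict([1.0, 0.0])
    learner.update_meta({"main_td_error": 3.0 - prediction, "reward": 0.0})
print(round(learner.predict([1.0, 0.0]), 4))  # 3.0
```

Inactive features (x_i = 0) leave both their weight and their step size untouched, which is why the second step size stays at its initial value here.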
class Perception(ContinualLearner, ABC, Generic[ObservationT, ActionT, SubjectiveStateT])
Sutton's Perception: observations → subjective state + feature management.
Turns raw observations into the agent's subjective state, the internal representation that every other module sees. Also discovers, ranks, and manages features (learned representational structures that grow over the agent's lifetime) and generates subtasks from the most useful ones.
The finer-grained layer splits this into StateBuilder,
FeatureBank, FeatureConstructor, FeatureRanker, and
SubtaskGenerator (see oak.fine_grained.components).
@abstractmethod
def reset(self) -> None
Reset all perception state for a new episode.
@abstractmethod
def update(self, observation: ObservationT, reward: float, last_action: ActionT | None) -> SubjectiveStateT
Process a new observation and return the updated subjective state.
@abstractmethod
def current_subjective_state(self) -> SubjectiveStateT
Return the most recently computed subjective state.
@abstractmethod
def discover_and_rank_features(self, subjective_state: SubjectiveStateT, utility_scores: Sequence[UtilityRecord], feature_budget: int) -> Sequence[FeatureId]
Propose new features, integrate them, and return the top-ranked IDs.
A typical implementation:
- Proposes candidate features from the current subjective state.
- Adds accepted candidates to its internal feature store.
- Ranks all features using the provided utility scores.
- Returns the top feature IDs (up to feature_budget).
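The four steps above could look like this in a toy setting where candidate features are "state variable i is non-zero" indicators and utility scores arrive as a plain {feature_id: score} dict. The real method receives Sequence[UtilityRecord], whose fields are not shown in this reference, so the dict, the candidate scheme, and the accept-everything rule are hypothetical.

```python
from typing import List, Mapping, Sequence, Set


class ToyFeatureStore:
    """Hypothetical perception fragment implementing discover-and-rank."""

    def __init__(self) -> None:
        self.features: Set[str] = set()

    def discover_and_rank_features(
        self,
        subjective_state: Sequence[float],
        utility_scores: Mapping[str, float],
        feature_budget: int,
    ) -> List[str]:
        # 1. Propose one indicator feature per non-zero state variable.
        candidates = [f"nonzero[{i}]" for i, v in enumerate(subjective_state) if v != 0]
        # 2. Accept every candidate (the set ignores ones already tracked).
        self.features.update(candidates)
        # 3. Rank all tracked features by utility; unscored features count as 0.0.
        ranked = sorted(self.features, key=lambda f: (-utility_scores.get(f, 0.0), f))
        # 4. Return at most feature_budget IDs.
        return ranked[:feature_budget]


store = ToyFeatureStore()
first = store.discover_and_rank_features([0.0, 2.0, 1.0], {"nonzero[2]": 0.8}, feature_budget=1)
second = store.discover_and_rank_features([3.0, 0.0, 0.0], {"nonzero[2]": 0.8, "nonzero[0]": 0.9}, feature_budget=2)
print(first, second)  # ['nonzero[2]'] ['nonzero[0]', 'nonzero[2]']
```

Ranking over all tracked features (not just the new candidates) lets an older feature win a slot back once its utility rises.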
@abstractmethod
def generate_subtasks(self, ranked_feature_ids: Sequence[FeatureId]) -> Sequence[SubtaskSpec]
Turn ranked feature IDs into subtask specifications.
@abstractmethod
def list_features(self) -> Sequence[FeatureSpec]
Return all currently tracked features.
@abstractmethod
def remove_features(self, feature_ids: Sequence[FeatureId]) -> None
Remove features by ID (called during curation).
class TransitionModel(ContinualLearner, ABC, Generic[SubjectiveStateT, ActionT, InfoT])
Sutton's Transition Model: world dynamics + option models + planning.
Learns from observed transitions, maintains option models that predict the effect of temporal abstractions, and runs bounded planning using the world model and the value function to produce improvement signals for the reactive policy.
The finer-grained layer splits this into WorldModel,
OptionModelLearner, individual OptionModel objects, and a
Planner (see oak.fine_grained.components).
@abstractmethod
def update(self, transition: Transition[ActionT, SubjectiveStateT, InfoT]) -> None
Learn from an observed transition.
This should update both the world model and any option-model learners.
@abstractmethod
def integrate_option_models(self) -> None
Export learned option models and integrate them into the world model.
Called after option learning so that planning reasons over fresh models.
@abstractmethod
def plan(self, subjective_state: SubjectiveStateT, value_function: ValueFunction[SubjectiveStateT, ActionT, InfoT], budget: int) -> PlanningUpdate[ActionT]
Run bounded planning and return improvement signals.
The planner uses the internal world model together with the supplied value_function (for state evaluation) to produce value targets, policy targets, or search statistics.
@abstractmethod
def remove_option_models(self, option_ids: Sequence[OptionId]) -> None
Remove option models by ID (called during curation).
class ValueFunction(ContinualLearner, ABC, Generic[SubjectiveStateT, ActionT, InfoT])
Sutton's Value Function: value learning + utility assessment + curation.
Learns predictive value signals (TD errors, GVF predictions) from observed transitions and predicts cumulative signals for any given subjective state. Also assesses the utility of the agent's learned structures (features, options, models) and produces concrete keep/drop curation decisions.
The finer-grained layer splits this into ValueEstimator,
GeneralValueFunctionLearner, UtilityAssessor, Curator, and
MetaStepSizeLearner (see oak.fine_grained.components).
@abstractmethod
def update(self, transition: Transition[ActionT, SubjectiveStateT, InfoT], *, planning: bool = False) -> Mapping[GeneralValueFunctionId, float]
Learn from a transition and return TD-error signals.
@abstractmethod
def predict(self, subjective_state: SubjectiveStateT) -> Mapping[GeneralValueFunctionId, float]
Predict values for the given subjective state.
@abstractmethod
def observe_usage(self, usage_records: Sequence[UsageRecord]) -> None
Record usage evidence for utility assessment.
@abstractmethod
def utility_scores(self) -> Sequence[UtilityRecord]
Return current utility estimates for all tracked structures.
@abstractmethod
def curate(self) -> CurationDecision
Decide which learned structures to drop.
@abstractmethod
def remove(self, general_value_function_ids: Sequence[GeneralValueFunctionId]) -> None
Remove value functions by ID (called during curation).
class ReactivePolicy(ContinualLearner, ABC, Generic[SubjectiveStateT, ActionT, InfoT])
Sutton's Reactive Policy: action selection + option management.
Selects actions, primitive or temporal abstractions (options), based on the current subjective state. Manages the option library and option learning pipeline, and integrates planning updates into decision-making.
The finer-grained layer splits this into ActionSelector,
OptionLibrary, and OptionLearner
(see oak.fine_grained.components).
@abstractmethod
def update(self, transition: Transition[ActionT, SubjectiveStateT, InfoT], td_errors: Mapping[GeneralValueFunctionId, float]) -> None
Update the policy and option learners from an observed transition.
@abstractmethod
def apply_planning_update(self, update: PlanningUpdate[ActionT]) -> None
Integrate planning improvement signals into the policy.
@abstractmethod
def ingest_subtasks(self, subtasks: Sequence[SubtaskSpec]) -> None
Feed newly created subtasks into the option learner.
@abstractmethod
def integrate_options(self) -> None
Export learned options into the option library.
@abstractmethod
def select_action(self, subjective_state: SubjectiveStateT, option_stop_threshold: float) -> tuple[ActionT, OptionId | None]
Choose a primitive action, possibly by continuing an active option.
Returns a (primitive_action, active_option_id) pair. When no option is active, active_option_id is None.
```python
@abstractmethod
def clear_active_option(self) -> None:
    """Clear the currently executing option (e.g. at episode boundaries)."""
    raise NotImplementedError
```
Clear the currently executing option (e.g. at episode boundaries).
```python
@abstractmethod
def remove_options(self, option_ids: Sequence[OptionId]) -> None:
    """Remove options by ID (called during curation)."""
    raise NotImplementedError
```
Remove options by ID (called during curation).
```python
@abstractmethod
def remove_subtasks(self, subtask_ids: Sequence[SubtaskId]) -> None:
    """Remove subtasks by ID (called during curation)."""
    raise NotImplementedError
```
Remove subtasks by ID (called during curation).
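As a concrete starting point, the sketch below implements only the `select_action` and `update(transition, td_errors)` parts of the interface as a tabular softmax actor driven by a critic's TD error (an actor-critic scheme swapped in for illustration). It assumes integer subjective states and actions, assumes the value function reports its main TD error under a hypothetical key `"main"`, and has no options, so the option-management methods are omitted.

```python
import math
import random
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class SketchTransition:
    # Stand-in for oak.Transition with only the fields this sketch reads.
    subjective_state: int
    action: int
    reward: float
    next_subjective_state: int
    terminated: bool = False


class SoftmaxActorSketch:
    """Tabular softmax actor updated from a critic's TD error (hypothetical key 'main')."""

    def __init__(self, num_states: int, num_actions: int, step_size: float = 0.1, seed: int = 0):
        self.prefs = [[0.0] * num_actions for _ in range(num_states)]
        self.step_size = step_size
        self.rng = random.Random(seed)

    def probabilities(self, state: int) -> list[float]:
        prefs = self.prefs[state]
        top = max(prefs)  # subtract the max for numerical stability
        exps = [math.exp(p - top) for p in prefs]
        total = sum(exps)
        return [e / total for e in exps]

    def select_action(self, state: int, option_stop_threshold: float) -> tuple[int, None]:
        # No options in this sketch: the threshold is unused and the option id is None.
        probs = self.probabilities(state)
        action = self.rng.choices(range(len(probs)), weights=probs)[0]
        return action, None

    def update(self, transition: SketchTransition, td_errors: Mapping[str, float]) -> None:
        delta = td_errors.get("main", 0.0)
        state, action = transition.subjective_state, transition.action
        probs = self.probabilities(state)
        # Softmax policy gradient: d log pi(a|s) / d h(s, b) = 1[b == a] - pi(b|s).
        for b, p in enumerate(probs):
            indicator = 1.0 if b == action else 0.0
            self.prefs[state][b] += self.step_size * delta * (indicator - p)
```

Reading the TD error from `td_errors` rather than recomputing it keeps the policy decoupled from the value function, which matches `OaKAgent.step`, where `value_function.update` runs before `reactive_policy.update`.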
```python
@dataclass
class OaKAgent(Generic[ObservationT, ActionT, SubjectiveStateT, InfoT]):
    """Coordinates one full OaK step across the four modules.

    The agent is a wiring object: you provide concrete implementations of
    `Perception`, `TransitionModel`, `ValueFunction`, and
    `ReactivePolicy`, and `OaKAgent` ensures they are called in a
    consistent order.
    """

    perception: Perception[ObservationT, ActionT, SubjectiveStateT]
    transition_model: TransitionModel[SubjectiveStateT, ActionT, InfoT]
    value_function: ValueFunction[SubjectiveStateT, ActionT, InfoT]
    reactive_policy: ReactivePolicy[SubjectiveStateT, ActionT, InfoT]

    planning_budget: int = 4
    feature_budget: int = 4
    option_stop_threshold: float = 0.5

    last_action: ActionT | None = None
    last_subjective_state: SubjectiveStateT | None = None
    last_active_option_id: OptionId | None = None

    def __init__(
        self,
        perception: Perception[ObservationT, ActionT, SubjectiveStateT],
        transition_model: TransitionModel[SubjectiveStateT, ActionT, InfoT],
        value_function: ValueFunction[SubjectiveStateT, ActionT, InfoT],
        reactive_policy: ReactivePolicy[SubjectiveStateT, ActionT, InfoT],
        planning_budget: int = 4,
        feature_budget: int = 4,
        option_stop_threshold: float = 0.5,
    ):
        self.perception = perception
        self.transition_model = transition_model
        self.value_function = value_function
        self.reactive_policy = reactive_policy
        self.planning_budget = planning_budget
        self.feature_budget = feature_budget
        self.option_stop_threshold = option_stop_threshold
        self.last_action = None
        self.last_subjective_state = None
        self.last_active_option_id = None
        # @dataclass does not call __post_init__ from a hand-written __init__,
        # so run the validation explicitly.
        self.__post_init__()

    def __post_init__(self):
        """Validate the budget and threshold settings."""
        if self.planning_budget < 1:
            raise ValueError("Planning budget must be at least 1.")
        if self.feature_budget < 1:
            raise ValueError("Feature budget must be at least 1.")
        if self.option_stop_threshold < 0 or self.option_stop_threshold > 1:
            raise ValueError("Option stop threshold must be in [0, 1].")

    def reset(self) -> None:
        """Clear transient execution memory."""
        self.perception.reset()
        self.reactive_policy.clear_active_option()
        self.last_action = None
        self.last_subjective_state = None
        self.last_active_option_id = None

    def step(
        self, time_step: TimeStep[ObservationT, InfoT]
    ) -> AgentStepResult[ActionT, SubjectiveStateT]:
        """Run one temporally uniform agent step.

        The step follows six phases: perceive, learn, grow, plan, act, maintain.
        """

        # ================== 1. Perceive ================= #
        subjective_state = self.perception.update(
            observation=time_step.observation,
            reward=time_step.reward,
            last_action=self.last_action,
        )

        created_subtasks: Sequence[SubtaskSpec] = ()
        ranked_feature_ids: Sequence[FeatureId] = ()
        planning_update: PlanningUpdate[ActionT] | None = None
        curation_decision: CurationDecision | None = None

        # ================== 2. Learn ================== #
        if self.last_subjective_state is not None and self.last_action is not None:
            transition = Transition(
                subjective_state=self.last_subjective_state,
                action=self.last_action,
                reward=time_step.reward,
                next_subjective_state=subjective_state,
                terminated=time_step.terminated or time_step.truncated,
                option_id=self.last_active_option_id,
                info=time_step.info,
            )
            td_errors = self.value_function.update(transition)
            self.reactive_policy.update(transition, td_errors)
            self.transition_model.update(transition)

            # Meta step-size adaptation (Sutton's IDBD / online cross-validation)
            meta_signals = dict(td_errors)
            meta_signals["reward"] = transition.reward
            self.perception.update_meta(meta_signals)
            self.value_function.update_meta(meta_signals)
            self.reactive_policy.update_meta(meta_signals)
            self.transition_model.update_meta(meta_signals)

        # ================== 3. Grow ================== #
        ranked_feature_ids = self.perception.discover_and_rank_features(
            subjective_state,
            self.value_function.utility_scores(),
            self.feature_budget,
        )
        if ranked_feature_ids:
            created_subtasks = self.perception.generate_subtasks(ranked_feature_ids)
            if created_subtasks:
                self.reactive_policy.ingest_subtasks(created_subtasks)

        self.reactive_policy.integrate_options()
        self.transition_model.integrate_option_models()

        # ================== 4. Plan ================== #
        planning_update = self.transition_model.plan(
            subjective_state, self.value_function, self.planning_budget
        )
        self.reactive_policy.apply_planning_update(planning_update)

        # ================== 5. Act ================== #
        action, active_option_id = self.reactive_policy.select_action(
            subjective_state, self.option_stop_threshold
        )

        # ================== 6. Maintain ================== #
        usage_records = self._build_usage_records(ranked_feature_ids, active_option_id)
        if usage_records:
            self.value_function.observe_usage(usage_records)

        curation_decision = self.value_function.curate()
        self._apply_curation(curation_decision)

        # ================== Update Memory ================== #
        self.last_subjective_state = subjective_state
        self.last_action = action
        self.last_active_option_id = active_option_id

        if time_step.terminated or time_step.truncated:
            self.reactive_policy.clear_active_option()

        return AgentStepResult(
            action=action,
            subjective_state=subjective_state,
            active_option_id=active_option_id,
            planning_update=planning_update,
            created_subtasks=created_subtasks,
            curation_decision=curation_decision,
        )

    # ── training loop ─────────────────────────────────────────────────

    def train(
        self,
        world: World[ObservationT, ActionT, InfoT],
        *,  # enforce keyword arguments after for clarity
        num_episodes: int = 500,
        average_window: int = 100,
        solved_threshold: float | None = None,
        episode_logger: Callable[[int, float, float, Self], None] | None = None,
        episode_trace_logger: Callable[
            [EpisodeTrace[ObservationT, ActionT, SubjectiveStateT, InfoT]], None
        ]
        | None = None,
        trace_selector: Callable[[int, int], bool] | None = None,
        capture_rendered_frames: bool = False,
    ) -> list[float]:
        """Run the standard OaK episode loop on the given world.

        Parameters
        ----------
        world:
            An environment implementing the `World` protocol.
        num_episodes:
            Maximum number of training episodes.
        average_window:
            Number of recent episodes to average for performance tracking.
        solved_threshold:
            If set, stop early once at least `average_window` episodes have
            finished and their average reward reaches this value.
        episode_logger:
            Optional callback `(episode, episode_reward, avg_reward, agent)`
            called after each episode. Use this to own all per-episode logging
            or other training-side effects at the call site.
        episode_trace_logger:
            Optional richer callback receiving an `EpisodeTrace` with the
            training world, agent, selected step records, and optionally
            rendered frames.
        trace_selector:
            Optional predicate `(episode, num_episodes) -> bool` used to decide
            which episodes should produce an `EpisodeTrace`. If omitted and an
            `episode_trace_logger` is provided, all episodes are traced.
        capture_rendered_frames:
            When `True`, collect frames from worlds that implement the optional
            `render_frame()` capability for traced episodes.

        Returns
        -------
        list[float]
            Per-episode reward history.
        """
        if average_window < 1:
            raise ValueError("average_window must be at least 1.")

        reward_history: list[float] = []

        for episode in range(num_episodes):
            time_step = world.reset()
            initial_time_step = time_step
            self.reset()
            episode_reward = 0.0
            step_count = 0
            capture_trace = episode_trace_logger is not None and (
                trace_selector(episode, num_episodes) if trace_selector is not None else True
            )
            traced_steps: list[
                EpisodeStepRecord[ObservationT, ActionT, SubjectiveStateT, InfoT]
            ] = []
            traced_frames: list[object] = []

            if capture_trace and capture_rendered_frames:
                frame = self._render_frame(world)
                if frame is not None:
                    traced_frames.append(frame)

            while True:
                result = self.step(time_step)

                if time_step.terminated or time_step.truncated:
                    break

                next_time_step = world.step(result.action)
                episode_reward += next_time_step.reward
                if capture_trace:
                    traced_steps.append(
                        EpisodeStepRecord(
                            step_index=step_count,
                            time_step=time_step,
                            action=result.action,
                            next_time_step=next_time_step,
                            active_option_id=result.active_option_id,
                            planning_update=result.planning_update,
                            created_subtasks=result.created_subtasks,
                        )
                    )
                    if capture_rendered_frames:
                        frame = self._render_frame(world)
                        if frame is not None:
                            traced_frames.append(frame)
                time_step = next_time_step
                step_count += 1

            reward_history.append(episode_reward)
            recent_window = reward_history[-average_window:]
            avg_reward = sum(recent_window) / len(recent_window)

            solved = (
                solved_threshold is not None
                and len(reward_history) >= average_window
                and avg_reward >= solved_threshold
            )

            if episode_logger is not None:
                episode_logger(episode, episode_reward, avg_reward, self)

            if episode_trace_logger is not None and capture_trace:
                episode_trace_logger(
                    EpisodeTrace(
                        episode=episode,
                        episode_reward=episode_reward,
                        avg_reward=avg_reward,
                        step_count=step_count,
                        solved=solved,
                        initial_time_step=initial_time_step,
                        final_time_step=time_step,
                        steps=tuple(traced_steps),
                        frames=tuple(traced_frames),
                        world=world,
                        agent=self,
                        metadata={
                            "num_episodes": num_episodes,
                            "capture_rendered_frames": capture_rendered_frames,
                        },
                    )
                )

            for component in (
                self.perception,
                self.value_function,
                self.transition_model,
                self.reactive_policy,
            ):
                end_episode = getattr(component, "end_episode", None)
                if callable(end_episode):
                    end_episode()

            if solved:
                break

        return reward_history

    # ── private helpers ──────────────────────────────────────────────

    def _build_usage_records(
        self,
        ranked_feature_ids: Sequence[FeatureId],
        active_option_id: OptionId | None,
    ) -> Sequence[UsageRecord]:
        """Build minimal utility-accounting observations for the current step."""
        usage_records = [
            UsageRecord(ComponentKind.FEATURE, feature_id)
            for feature_id in ranked_feature_ids
        ]
        if active_option_id is not None:
            usage_records.append(UsageRecord(ComponentKind.OPTION, active_option_id))
        return tuple(usage_records)

    def _apply_curation(self, decision: CurationDecision) -> None:
        """Dispatch curation decisions to the relevant modules."""
        if decision.drop_features:
            self.perception.remove_features(decision.drop_features)
        if decision.drop_subtasks:
            self.reactive_policy.remove_subtasks(decision.drop_subtasks)
        if decision.drop_options:
            self.reactive_policy.remove_options(decision.drop_options)
        if decision.drop_option_models:
            self.transition_model.remove_option_models(decision.drop_option_models)
        if decision.drop_general_value_functions:
            self.value_function.remove(decision.drop_general_value_functions)

    def _render_frame(
        self,
        world: World[ObservationT, ActionT, InfoT],
    ) -> object | None:
        """Capture one world frame when optional rendering is available."""
        if not isinstance(world, RenderableWorld):
            return None
        return world.render_frame()
```
Coordinates one full OaK step across the four modules.
The agent is a wiring object: you provide concrete implementations of Perception, TransitionModel, ValueFunction, and ReactivePolicy, and OaKAgent ensures they are called in a consistent order.
```python
def reset(self) -> None: ...
```
Clear transient execution memory.
```python
def step(
    self, time_step: TimeStep[ObservationT, InfoT]
) -> AgentStepResult[ActionT, SubjectiveStateT]: ...
```
Run one temporally uniform agent step.
The step follows six phases: perceive, learn, grow, plan, act, maintain.
```python
def train(
    self,
    world: World[ObservationT, ActionT, InfoT],
    *,
    num_episodes: int = 500,
    average_window: int = 100,
    solved_threshold: float | None = None,
    episode_logger: Callable[[int, float, float, Self], None] | None = None,
    episode_trace_logger: Callable[
        [EpisodeTrace[ObservationT, ActionT, SubjectiveStateT, InfoT]], None
    ]
    | None = None,
    trace_selector: Callable[[int, int], bool] | None = None,
    capture_rendered_frames: bool = False,
) -> list[float]: ...
```
Run the standard OaK episode loop on the given world.
Parameters
world:
An environment implementing the World protocol.
num_episodes:
Maximum number of training episodes.
average_window:
Number of recent episodes to average for performance tracking.
solved_threshold:
If set, stop early once at least average_window episodes have finished and their average reward reaches this value.
episode_logger:
Optional callback (episode, episode_reward, avg_reward, agent)
called after each episode. Use this to own all per-episode logging
or other training-side effects at the call site.
episode_trace_logger:
Optional richer callback receiving an EpisodeTrace with the
training world, agent, selected step records, and optionally
rendered frames.
trace_selector:
Optional predicate (episode, num_episodes) -> bool used to decide
which episodes should produce an EpisodeTrace. If omitted and an
episode_trace_logger is provided, all episodes are traced.
capture_rendered_frames:
When True, collect frames from worlds that implement the optional
render_frame() capability for traced episodes.
Returns
list[float]
Per-episode reward history.
```python
@runtime_checkable
class World(Protocol[WorldObservationT, WorldActionT, WorldInfoT]):
    """Minimal environment protocol.

    A `World` may wrap a simulator, a benchmark environment, or a custom
    continual data source. The protocol is intentionally small so the
    package does not depend on a specific environment library.

    Implement this protocol for any environment you want to use with
    `OaKAgent.train()`.
    """

    def reset(self) -> TimeStep[WorldObservationT, WorldInfoT]: ...

    def step(
        self, action: WorldActionT
    ) -> TimeStep[WorldObservationT, WorldInfoT]: ...

    def close(self) -> None:
        """Release environment resources. Default is a no-op."""
        ...
```
Minimal environment protocol.
A World may wrap a simulator, a benchmark environment, or a custom continual data source. The protocol is intentionally small so the package does not depend on a specific environment library.
Implement this protocol for any environment you want to use with OaKAgent.train().
```python
def reset(self) -> TimeStep[WorldObservationT, WorldInfoT]: ...
```
```python
@dataclass(slots=True, frozen=True)
class AgentStepResult(Generic[ActionT, SubjectiveStateT]):
    """Observable result of one OaK agent step.

    This is the compact object a caller receives after stepping the agent. It
    includes the primitive action actually executed, the current subjective
    state, and any structures or planning signals created during that step.
    """

    action: ActionT
    subjective_state: SubjectiveStateT
    active_option_id: OptionId | None = None
    planning_update: PlanningUpdate[ActionT] | None = None
    created_subtasks: Sequence[SubtaskSpec] = field(default_factory=tuple)
    curation_decision: CurationDecision | None = None
```
Observable result of one OaK agent step.
This is the compact object a caller receives after stepping the agent. It includes the primitive action actually executed, the current subjective state, and any structures or planning signals created during that step.
```python
class ComponentKind(str, Enum):
    """Kinds of learnable or managed elements in the architecture."""

    FEATURE = "feature"
    SUBTASK = "subtask"
    OPTION = "option"
    VALUE_FUNCTION = "value_function"
    OPTION_MODEL = "option_model"
    TRANSITION_MODEL = "transition_model"
    POLICY = "policy"
    PERCEPTION = "perception"
    PLANNER = "planner"
```
Kinds of learnable or managed elements in the architecture.
```python
@dataclass(slots=True, frozen=True)
class CurationDecision:
    """Pruning decision returned by the curator."""

    drop_features: Sequence[FeatureId] = field(default_factory=tuple)
    drop_subtasks: Sequence[SubtaskId] = field(default_factory=tuple)
    drop_options: Sequence[OptionId] = field(default_factory=tuple)
    drop_option_models: Sequence[OptionId] = field(default_factory=tuple)
    drop_general_value_functions: Sequence[GeneralValueFunctionId] = field(
        default_factory=tuple
    )
    notes: StructuredPayload = field(default_factory=dict)
```
Pruning decision returned by the curator.
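The mapping from `CurationDecision` fields to module calls follows `OaKAgent._apply_curation`. The sketch below restates that routing as data so it can be inspected without real modules; `SketchCurationDecision`, `ROUTES`, and `route_curation` are illustrative names, and the `notes` field is omitted.

```python
from dataclasses import dataclass, field
from typing import Sequence


@dataclass(frozen=True)
class SketchCurationDecision:
    # Stand-in with the same drop_* fields as oak.CurationDecision.
    drop_features: Sequence[str] = field(default_factory=tuple)
    drop_subtasks: Sequence[str] = field(default_factory=tuple)
    drop_options: Sequence[str] = field(default_factory=tuple)
    drop_option_models: Sequence[str] = field(default_factory=tuple)
    drop_general_value_functions: Sequence[str] = field(default_factory=tuple)


# Which module method receives each field, in the order _apply_curation checks them.
ROUTES = {
    "drop_features": "perception.remove_features",
    "drop_subtasks": "reactive_policy.remove_subtasks",
    "drop_options": "reactive_policy.remove_options",
    "drop_option_models": "transition_model.remove_option_models",
    "drop_general_value_functions": "value_function.remove",
}


def route_curation(decision: SketchCurationDecision) -> list[tuple[str, tuple[str, ...]]]:
    """Return (target method, ids) pairs for every non-empty drop list."""
    calls = []
    for field_name, target in ROUTES.items():
        ids = tuple(getattr(decision, field_name))
        if ids:  # empty lists are skipped, as in the agent
            calls.append((target, ids))
    return calls
```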
```python
@dataclass(slots=True, frozen=True)
class FeatureCandidate:
    """A proposed feature that may be admitted into the feature bank."""

    feature_id: FeatureId
    name: str
    origin: str
    description: str = ""
    metadata: OpenPayload = field(default_factory=dict)
```
A proposed feature that may be admitted into the feature bank.
```python
@dataclass(slots=True, frozen=True)
class FeatureSpec:
    """Metadata describing a feature tracked by the agent."""

    feature_id: FeatureId
    name: str
    description: str = ""
    metadata: OpenPayload = field(default_factory=dict)
```
Metadata describing a feature tracked by the agent.
```python
@dataclass(slots=True, frozen=True)
class GeneralValueFunctionSpec(Generic[ActionT, SubjectiveStateT, InfoT]):
    """General value function specification."""

    general_value_function_id: GeneralValueFunctionId
    name: str
    cumulant: ScalarSignal
    continuation: ContinuationFn
    termination_value: TerminationValueFn
    metadata: OpenPayload = field(default_factory=dict)
```
General value function specification.
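The types `ScalarSignal`, `ContinuationFn`, and `TerminationValueFn` are not shown on this page, so the sketch below assumes each is a function of the transition. Under that assumption it performs a tabular TD(0) update in the Horde-style GVF form, where the one-step target is the cumulant c, plus the continuation-weighted next value gamma * v(s'), plus the termination-weighted termination value (1 - gamma) * z. `TabularGVFSketch` and `SketchTransition` are illustrative stand-ins.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class SketchTransition:
    subjective_state: int
    action: int
    reward: float
    next_subjective_state: int
    terminated: bool = False


# Assumed signature: each GVF component reads the transition.
SignalFn = Callable[[SketchTransition], float]


@dataclass
class TabularGVFSketch:
    cumulant: SignalFn           # c: the signal being predicted
    continuation: SignalFn       # gamma in [0, 1]
    termination_value: SignalFn  # z: value received when the question terminates
    step_size: float = 0.1

    def __post_init__(self) -> None:
        self.values: dict[int, float] = {}

    def value(self, state: int) -> float:
        return self.values.get(state, 0.0)

    def update(self, t: SketchTransition) -> float:
        """One TD(0) step; returns the TD error."""
        c = self.cumulant(t)
        # Environment termination overrides the GVF's own continuation.
        gamma = 0.0 if t.terminated else self.continuation(t)
        z = self.termination_value(t)
        target = c + gamma * self.value(t.next_subjective_state) + (1.0 - gamma) * z
        delta = target - self.value(t.subjective_state)
        self.values[t.subjective_state] = self.value(t.subjective_state) + self.step_size * delta
        return delta
```

Forcing `gamma` to 0 on termination is a choice of this sketch. Note that `OaKAgent.step` also sets `Transition.terminated` for truncated episodes, so a learner that wants to bootstrap through time limits would need another way to see the raw `TimeStep` flags.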
```python
@dataclass(slots=True, frozen=True)
class ModelPrediction(Generic[SubjectiveStateT]):
    """Prediction returned by an action or option model."""

    predicted_subjective_state: SubjectiveStateT
    cumulative_reward: float
    steps: int | None = None
    terminated: bool = False
    metadata: OpenPayload = field(default_factory=dict)
```
Prediction returned by an action or option model.
```python
@dataclass(slots=True, frozen=True)
class OptionDescriptor:
    """Lightweight metadata for an option."""

    option_id: OptionId
    name: str
    subtask_id: SubtaskId | None = None
    metadata: OpenPayload = field(default_factory=dict)
```
Lightweight metadata for an option.
```python
@dataclass(slots=True, frozen=True)
class PlanningUpdate(Generic[ActionT]):
    """Outputs from one planning pass."""

    value_targets: Mapping[GeneralValueFunctionId, float] = field(default_factory=dict)
    policy_targets: StructuredPayload = field(default_factory=dict)
    search_statistics: StructuredPayload = field(default_factory=dict)
```
Outputs from one planning pass.
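`TransitionModel.plan(subjective_state, value_function, planning_budget)` may use any search method. One simple possibility, sketched below, is Dyna-Q style replay: a learned deterministic model of `(state, action) -> (reward, next_state)` is sampled `planning_budget` times and each sample receives a Q-learning backup. The value target is keyed by a hypothetical GVF id `"main"`, `SketchPlanningUpdate` stands in for `PlanningUpdate`, and terminal states are not modelled.

```python
import random
from dataclasses import dataclass, field


@dataclass(frozen=True)
class SketchPlanningUpdate:
    # Stand-in for oak.PlanningUpdate (policy_targets omitted).
    value_targets: dict = field(default_factory=dict)
    search_statistics: dict = field(default_factory=dict)


class DynaQPlannerSketch:
    """Tabular Dyna-Q planning over a learned deterministic model."""

    def __init__(self, actions: list[int], step_size: float = 0.5, gamma: float = 0.9, seed: int = 0):
        self.actions = actions
        self.step_size = step_size
        self.gamma = gamma
        self.model: dict[tuple[int, int], tuple[float, int]] = {}  # (s, a) -> (r, s')
        self.q: dict[tuple[int, int], float] = {}
        self.rng = random.Random(seed)

    def update_model(self, s: int, a: int, r: float, s_next: int) -> None:
        self.model[(s, a)] = (r, s_next)

    def max_q(self, s: int) -> float:
        return max(self.q.get((s, a), 0.0) for a in self.actions)

    def plan(self, state: int, planning_budget: int) -> SketchPlanningUpdate:
        keys = list(self.model)
        backups = planning_budget if keys else 0  # nothing to replay without a model
        for _ in range(backups):
            s, a = self.rng.choice(keys)
            r, s_next = self.model[(s, a)]
            old = self.q.get((s, a), 0.0)
            # Q-learning backup on a simulated transition.
            self.q[(s, a)] = old + self.step_size * (r + self.gamma * self.max_q(s_next) - old)
        return SketchPlanningUpdate(
            value_targets={"main": self.max_q(state)},
            search_statistics={"backups": backups},
        )
```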
```python
@dataclass(slots=True, frozen=True)
class PolicyDecision(Generic[ActionT]):
    """Return type for reactive policy selection."""

    action: ActionT | None = None
    option_id: OptionId | None = None
    metadata: OpenPayload = field(default_factory=dict)

    def __post_init__(self) -> None:
        has_action = self.action is not None
        has_option = self.option_id is not None
        if has_action == has_option:
            raise ValueError(
                "PolicyDecision requires exactly one of action or option_id."
            )
```
Return type for reactive policy selection.
```python
@dataclass(slots=True, frozen=True)
class SubtaskSpec:
    """A feature-grounded subtask description."""

    subtask_id: SubtaskId
    name: str
    feature_id: FeatureId
    intensity: float = 1.0
    general_value_function_id: GeneralValueFunctionId | None = None
    metadata: OpenPayload = field(default_factory=dict)
```
A feature-grounded subtask description.
```python
@dataclass(slots=True, frozen=True)
class TimeStep(Generic[ObservationT, InfoT]):
    """One environment emission seen by the agent.

    `TimeStep` is the object passed into `OaKAgent.step(...)`. It contains the
    raw observation, scalar reward, episode-control flags, and optional
    environment metadata.
    """

    observation: ObservationT
    reward: float
    terminated: bool = False
    truncated: bool = False
    info: InfoT | None = None
```
One environment emission seen by the agent.
TimeStep is the object passed into OaKAgent.step(...). It contains the raw observation, scalar reward, episode-control flags, and optional environment metadata.
```python
@dataclass(slots=True, frozen=True)
class Transition(Generic[ActionT, SubjectiveStateT, InfoT]):
    """One subjective-state transition in agent terms.

    `Transition` is constructed by the agent after two consecutive time steps.
    Learners use it instead of the raw world stream so they can access both the
    previous and next subjective state representations together with reward,
    termination, and optional environment metadata.
    """

    subjective_state: SubjectiveStateT
    action: ActionT
    reward: float
    next_subjective_state: SubjectiveStateT
    terminated: bool = False
    option_id: OptionId | None = None
    info: InfoT | None = None
```
One subjective-state transition in agent terms.
Transition is constructed by the agent after two consecutive time steps. Learners use it instead of the raw world stream so they can access both the previous and next subjective state representations together with reward, termination, and optional environment metadata.
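The pairing performed in `OaKAgent.step` can be reproduced on a recorded stream. In the sketch below (stand-in types, hypothetical helper `pair_transitions`), no transition exists for the first time step, each transition takes its reward and flags from the newer time step, and truncation is folded into `terminated` as the agent does.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class TimeStepSketch:
    observation: Any
    reward: float
    terminated: bool = False
    truncated: bool = False


@dataclass(frozen=True)
class TransitionSketch:
    subjective_state: Any
    action: Any
    reward: float
    next_subjective_state: Any
    terminated: bool = False


def pair_transitions(
    time_steps: Sequence[TimeStepSketch],
    actions: Sequence[Any],
    perceive: Callable[[Any], Any] = lambda obs: obs,  # identity "perception"
) -> list[TransitionSketch]:
    """Build one transition per consecutive pair of time steps."""
    transitions = []
    last_state = last_action = None
    for time_step, action in zip(time_steps, actions):
        state = perceive(time_step.observation)
        if last_state is not None and last_action is not None:
            transitions.append(
                TransitionSketch(
                    subjective_state=last_state,
                    action=last_action,
                    reward=time_step.reward,  # reward arrives with the newer step
                    next_subjective_state=state,
                    terminated=time_step.terminated or time_step.truncated,
                )
            )
        last_state, last_action = state, action
    return transitions
```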
```python
@dataclass(slots=True, frozen=True)
class UsageRecord:
    """Usage evidence gathered for utility assessment."""

    kind: ComponentKind
    component_id: ComponentId
    amount: float = 1.0
    metadata: OpenPayload = field(default_factory=dict)
```
Usage evidence gathered for utility assessment.
```python
@dataclass(slots=True, frozen=True)
class UtilityRecord:
    """Utility score for one architectural element."""

    kind: ComponentKind
    component_id: ComponentId
    utility: float
    evidence: StructuredPayload = field(default_factory=dict)
```
Utility score for one architectural element.
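`OaKAgent._build_usage_records` emits one `UsageRecord` per ranked feature and one for the active option on each step; how usage becomes utility is left to the value function. One hypothetical rule, sketched below, keeps an exponentially decayed usage score per component and drops anything whose score falls below a threshold. The stand-ins `Kind`, `Usage`, and `Utility` mirror `ComponentKind`, `UsageRecord`, and `UtilityRecord`.

```python
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum


class Kind(str, Enum):
    # Subset of oak.ComponentKind values.
    FEATURE = "feature"
    OPTION = "option"


@dataclass(frozen=True)
class Usage:
    kind: Kind
    component_id: str
    amount: float = 1.0


@dataclass(frozen=True)
class Utility:
    kind: Kind
    component_id: str
    utility: float


class DecayedUsageTracker:
    """Hypothetical utility rule: decayed usage counts with a drop threshold."""

    def __init__(self, decay: float = 0.9, drop_below: float = 0.1) -> None:
        self.decay = decay
        self.drop_below = drop_below
        self.scores: dict[tuple[Kind, str], float] = defaultdict(float)

    def observe_usage(self, records: list[Usage]) -> None:
        # Every tracked score decays each step; used components are then reinforced.
        for key in self.scores:
            self.scores[key] *= self.decay
        for record in records:
            self.scores[(record.kind, record.component_id)] += record.amount

    def utilities(self) -> list[Utility]:
        return [Utility(kind, cid, score) for (kind, cid), score in self.scores.items()]

    def curate(self) -> dict[Kind, tuple[str, ...]]:
        """Return the ids to drop, grouped by kind, and stop tracking them."""
        doomed = [key for key, score in self.scores.items() if score < self.drop_below]
        for key in doomed:
            del self.scores[key]
        drops: dict[Kind, tuple[str, ...]] = {}
        for kind, cid in doomed:
            drops[kind] = drops.get(kind, ()) + (cid,)
        return drops
```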