Source code for training_ml_control.environments.utils

from collections import defaultdict
from dataclasses import dataclass

import networkx as nx
import numpy as np
from gymnasium import Env
from gymnasium.envs.classic_control.pendulum import PendulumEnv
from gymnasium.wrappers import OrderEnforcing, PassiveEnvChecker, TimeLimit
from gymnasium.wrappers.render_collection import RenderCollection
from numpy.typing import NDArray

from training_ml_control.control import FeedbackController, Observer, RandomController
from training_ml_control.environments.cart import CartEnv
from training_ml_control.environments.grid_world import GridWorldEnv
from training_ml_control.environments.inverted_pendulum import InvertedPendulumEnv

__all__ = [
    "create_inverted_pendulum_environment",
    "create_grid_world_environment",
    "create_cart_environment",
    "create_pendulum_environment",
    "simulate_environment",
    "value_iteration",
    "compute_best_path_and_actions_from_values",
]


def create_cart_environment(
    render_mode: str | None = "rgb_array",
    *,
    max_steps: int = 200,
    goal_velocity: float = 0,
    max_position: float = 10,
    max_speed: float = 10,
    max_force: float = 10,
    goal_position: float = 9.0,
) -> Env:
    """Creates instance of CartEnv with some wrappers to ensure correctness,
    limit the number of steps and store rendered frames.
    """
    env = CartEnv(
        render_mode=render_mode,
        goal_velocity=goal_velocity,
        max_position=max_position,
        max_speed=max_speed,
        max_force=max_force,
        goal_position=goal_position,
    )
    env = TimeLimit(env, max_steps)
    # env = PassiveEnvChecker(env)
    env = OrderEnforcing(env)
    if render_mode is not None:
        env = RenderCollection(env)
    return env
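
# Illustrative usage sketch (an assumption, not part of the original module):
# building a cart environment and stepping it once with a sampled action.
#
#     env = create_cart_environment(max_steps=50)
#     observation, info = env.reset(seed=0)
#     observation, reward, terminated, truncated, info = env.step(
#         env.action_space.sample()
#     )
#     env.close()
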
def create_pendulum_environment(
    render_mode: str | None = "rgb_array",
    *,
    max_steps: int = 200,
) -> Env:
    """Creates instance of PendulumEnv with some wrappers to ensure correctness,
    limit the number of steps and store rendered frames.
    """
    env = PendulumEnv(render_mode=render_mode)
    env = TimeLimit(env, max_steps)
    # env = PassiveEnvChecker(env)
    env = OrderEnforcing(env)
    if render_mode is not None:
        env = RenderCollection(env)
    return env
def create_grid_world_environment(
    render_mode: str | None = "rgb_array",
    *,
    max_steps: int = 20,
) -> Env:
    """Creates instance of GridWorldEnv with some wrappers to ensure correctness,
    limit the number of steps and store rendered frames.
    """
    env = GridWorldEnv(render_mode=render_mode, max_steps=max_steps)
    # env = PassiveEnvChecker(env)
    env = OrderEnforcing(env)
    if render_mode is not None:
        env = RenderCollection(env)
    return env
def create_inverted_pendulum_environment(
    render_mode: str | None = "rgb_array",
    *,
    max_steps: int = 500,
    masspole: float | None = None,
    masscart: float | None = None,
    length: float | None = None,
    x_threshold: float = 3,
    theta_initial: float = 0.0,
    theta_threshold: float = 24,
    force_max: float = 10.0,
) -> Env:
    """Creates instance of InvertedPendulumEnv with some wrappers to ensure
    correctness, limit the number of steps and store rendered frames.

    Args:
        render_mode: Render mode for environment.
        max_steps: Maximum number of steps in the environment before termination.
        masspole: Mass of the pole.
        masscart: Mass of the cart.
        length: Length of the pole.
        x_threshold: Threshold value for cart position.
        theta_initial: Initial value for pole angle.
        theta_threshold: Threshold value for pole angle.
        force_max: Maximum absolute value for force applied to cart.

    Returns:
        Instantiated and wrapped environment.
    """
    env = InvertedPendulumEnv(
        masspole=masspole,
        masscart=masscart,
        length=length,
        x_threshold=x_threshold,
        theta_threshold=theta_threshold,
        theta_initial=theta_initial,
        force_max=force_max,
        render_mode=render_mode,
    )
    env = PassiveEnvChecker(env)
    env = OrderEnforcing(env)
    env = TimeLimit(env, max_steps)
    if render_mode is not None:
        env = RenderCollection(env)
    return env
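
# Illustrative sketch (an assumption, not from the original source): the keyword
# parameters override the environment defaults, e.g. a heavier pole and a
# tighter cart-position limit.
#
#     env = create_inverted_pendulum_environment(
#         max_steps=300,
#         masspole=0.2,
#         x_threshold=2.0,
#     )
#     observation, info = env.reset(seed=0)
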
@dataclass
class SimulationResults:
    frames: list[NDArray]
    observations: NDArray
    estimated_observations: NDArray
    actions: NDArray
def simulate_environment(
    env: Env,
    *,
    max_steps: int = 500,
    controller: FeedbackController | None = None,
    observer: Observer | None = None,
    seed: int = 16,
) -> SimulationResults:
    """Runs a closed-loop simulation of the environment for at most max_steps
    steps, falling back to a RandomController if no controller is given, and
    returns the collected frames, observations, estimated observations and actions.
    """
    if controller is None:
        controller = RandomController(env)
    observation, _ = env.reset(seed=seed)
    actions = []
    observations = [observation]
    estimated_observations = []
    frames = []
    if observer is not None:
        estimated_observation = observer.observe(observation)
        estimated_observations.append(estimated_observation)

    for _ in range(max_steps):
        action = controller.act(observation)
        observation, _, terminated, truncated, _ = env.step(action)
        observations.append(observation)
        actions.append(action)
        if observer is not None:
            estimated_observation = observer.observe(observation)
            estimated_observations.append(estimated_observation)
        # Check if we need to stop the simulation
        if terminated or truncated:
            break

    if env.render_mode is not None:
        frames = env.render()
    env.reset()
    env.close()

    actions = np.stack(actions)
    observations = np.stack(observations)
    if estimated_observations:
        estimated_observations = np.stack(estimated_observations)
    return SimulationResults(
        frames=frames,
        observations=observations,
        estimated_observations=estimated_observations,
        actions=actions,
    )
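
# Illustrative sketch (an assumption): running a short simulation with the
# RandomController fallback. Any FeedbackController/Observer pair from
# training_ml_control.control could be passed in instead.
#
#     env = create_inverted_pendulum_environment(max_steps=100)
#     results = simulate_environment(env, max_steps=100)
#     print(results.observations.shape, results.actions.shape)
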
def value_iteration(G: nx.DiGraph) -> dict[tuple[int, int], float]:
    """Computes the cost-to-go of every node in the graph by value iteration,
    assuming a unit cost for every edge; nodes without successors keep a value of 0.
    """
    values = defaultdict(lambda: 0.0)
    delta = np.inf
    while delta > 0.0:
        delta = 0.0
        Q = defaultdict(lambda: defaultdict(lambda: 0.0))
        visited_nodes = []
        for node in G.nodes:
            if node in visited_nodes:
                continue
            visited_nodes.append(node)
            next_nodes = list(G.successors(node))
            if not next_nodes:
                # Terminal node: its value stays at 0
                continue
            for next_node in G.successors(node):
                new_value = 1.0 + values[next_node]
                Q[node][next_node] = new_value
            min_q = min(Q[node].values())
            delta = max(delta, abs(values[node] - min_q))
            values[node] = min_q
    return values
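
# Illustrative sketch (an assumption): value iteration on a tiny hand-built
# graph. Node (1, 1) has no successors, so it acts as the goal with cost-to-go 0.
#
#     G = nx.DiGraph()
#     G.add_edge((0, 0), (0, 1), action=0)
#     G.add_edge((0, 1), (1, 1), action=1)
#     G.add_edge((0, 0), (1, 0), action=1)
#     G.add_edge((1, 0), (1, 1), action=0)
#     values = value_iteration(G)
#     # values[(0, 0)] == 2.0, values[(0, 1)] == values[(1, 0)] == 1.0
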
def compute_best_path_and_actions_from_values(
    G: nx.DiGraph,
    start_node: tuple[int, int],
    target_node: tuple[int, int],
    values: dict[tuple[int, int], float],
) -> tuple[list[tuple[int, int]], list[int]]:
    """Greedily follows the values from start_node towards target_node and
    returns the resulting path together with the actions attached to its edges.
    """
    best_path = [start_node]
    actions = []
    current_node = start_node
    while current_node != target_node:
        next_nodes = list(G.successors(current_node))
        if not next_nodes:
            break
        current_values = defaultdict(lambda: 0.0)
        for next_node in G.successors(current_node):
            value = 1.0 + values[next_node]
            current_values[next_node] = value
        # Move to the successor with the smallest cost-to-go
        best_next_node = min(current_values, key=current_values.get)
        best_path.append(best_next_node)
        action = G.edges[(current_node, best_next_node)].get("action")
        actions.append(action)
        current_node = best_next_node
    return best_path, actions
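
# Continuing the sketch above (an assumption): recovering the greedy path and
# the actions along it from the computed values.
#
#     path, actions = compute_best_path_and_actions_from_values(
#         G, start_node=(0, 0), target_node=(1, 1), values=values
#     )
#     # path == [(0, 0), (0, 1), (1, 1)] and actions == [0, 1]
#     # (or the equal-cost path through (1, 0), depending on successor order)
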