Source code for training_classical_control.environment
from dataclasses import dataclass

import numpy as np
from gymnasium import Env
from gymnasium.wrappers import OrderEnforcing, PassiveEnvChecker, TimeLimit
from gymnasium.wrappers.render_collection import RenderCollection
from numpy.typing import NDArray

from training_classical_control.control import (
    FeedbackController,
    Observer,
    RandomController,
)
from training_classical_control.inverted_pendulum import InvertedPendulumEnv
__all__ = [
    "create_inverted_pendulum_environment",
    "simulate_environment",
]
def create_inverted_pendulum_environment(
    render_mode: str | None = "rgb_array",
    *,
    max_steps: int = 500,
    masspole: float = 0.1,
    masscart: float = 1.0,
    length: float = 1.0,
    x_threshold: float = 3.0,
    theta_threshold: float = 24.0,
    force_max: float = 10.0,
) -> Env:
    """Create an InvertedPendulumEnv instance wrapped to ensure correct API usage,
    limit the number of steps, and store rendered frames.

    Args:
        render_mode: Render mode for the environment.
        max_steps: Maximum number of steps before the episode is truncated.
        masspole: Mass of the pole.
        masscart: Mass of the cart.
        length: Length of the pole.
        x_threshold: Threshold value for the cart position.
        theta_threshold: Threshold value for the pole angle.
        force_max: Maximum absolute value of the force applied to the cart.

    Returns:
        Instantiated and wrapped environment.
    """
    env = InvertedPendulumEnv(
        masspole=masspole,
        masscart=masscart,
        length=length,
        x_threshold=x_threshold,
        theta_threshold=theta_threshold,
        force_max=force_max,
        render_mode=render_mode,
    )
    # Validate API usage, enforce reset-before-step ordering, and truncate
    # episodes that exceed max_steps.
    env = PassiveEnvChecker(env)
    env = OrderEnforcing(env)
    env = TimeLimit(env, max_steps)
    # Collect rendered frames so they can be retrieved after the episode.
    if render_mode is not None:
        env = RenderCollection(env)
    return env
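

# A minimal usage sketch, not part of the original module; the function name
# `_example_random_rollout` is illustrative only. It shows how the factory
# above is typically combined with a short random rollout.
def _example_random_rollout(num_steps: int = 10) -> None:
    env = create_inverted_pendulum_environment(max_steps=num_steps)
    observation, _ = env.reset()
    for _ in range(num_steps):
        # Sample a random action from the environment's action space.
        action = env.action_space.sample()
        observation, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break
    env.close()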
@dataclass
class SimulationResults:
    """Container for the trajectories and frames recorded during a simulation."""

    frames: list[NDArray]
    observations: NDArray
    estimated_observations: NDArray
    actions: NDArray
def simulate_environment(
    env: Env,
    *,
    max_steps: int = 500,
    controller: FeedbackController | None = None,
    observer: Observer | None = None,
) -> SimulationResults:
    """Run a closed-loop simulation of the environment and record the trajectory.

    Args:
        env: Environment to simulate.
        max_steps: Maximum number of simulation steps.
        controller: Controller used to select actions. Defaults to a
            RandomController if not provided.
        observer: Optional observer producing state estimates from observations.

    Returns:
        SimulationResults with the recorded frames, observations, estimated
        observations and actions.
    """
    if controller is None:
        controller = RandomController(env)

    observation, _ = env.reset()

    actions = []
    observations = [observation]
    estimated_observations = []
    frames = []

    if observer is not None:
        estimated_observation = observer.observe(observation)
        estimated_observations.append(estimated_observation)

    for _ in range(max_steps):
        action = controller.act(observation)
        observation, _, terminated, truncated, _ = env.step(action)
        observations.append(observation)
        actions.append(action)
        if observer is not None:
            estimated_observation = observer.observe(observation)
            estimated_observations.append(estimated_observation)
        # Stop the simulation once the episode terminates or is truncated,
        # collecting the rendered frames accumulated by RenderCollection.
        if terminated or truncated:
            if env.render_mode is not None:
                frames = env.render()
            env.reset()
            break

    env.close()

    actions = np.stack(actions)
    observations = np.stack(observations)
    if estimated_observations:
        estimated_observations = np.stack(estimated_observations)

    return SimulationResults(
        frames=frames,
        observations=observations,
        estimated_observations=estimated_observations,
        actions=actions,
    )
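

# A minimal usage sketch, not part of the original module; the function name
# `_example_simulation` is illustrative only. It runs a short simulation with
# the default RandomController and inspects the recorded trajectories.
def _example_simulation() -> None:
    env = create_inverted_pendulum_environment(max_steps=200)
    results = simulate_environment(env, max_steps=200)
    print("observations shape:", results.observations.shape)
    print("actions shape:", results.actions.shape)
    print("frames collected:", len(results.frames))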