# Source code for src.plangym.utils

"""Generic utilities for working with environments."""

import os

import gymnasium as gym
from gymnasium.spaces import Box
from gymnasium.wrappers.time_limit import TimeLimit
import numpy
from pyvirtualdisplay import Display
import cv2

try:
    from PIL import Image

    USE_PIL = True
except ImportError:  # pragma: no cover
    USE_PIL = False


def get_display(visible=False, size=(400, 400), **kwargs):
    """Create, start and return a virtual display.

    Before the display is built, the ``PYVIRTUALDISPLAY_DISPLAYFD``
    environment variable is set to ``"0"``.

    Args:
        visible: Forwarded to ``pyvirtualdisplay.Display``.
        size: Forwarded to ``pyvirtualdisplay.Display``.
        **kwargs: Extra keyword arguments forwarded to ``pyvirtualdisplay.Display``.

    Returns:
        The ``Display`` instance after ``start()`` has been called on it.
    """
    os.environ.update({"PYVIRTUALDISPLAY_DISPLAYFD": "0"})
    virtual_display = Display(visible=visible, size=size, **kwargs)
    virtual_display.start()
    return virtual_display
def remove_time_limit_from_spec(spec):
    """Remove the maximum time limit of an environment spec.

    Each of ``max_episode_steps`` and ``max_episode_time`` that is present on
    ``spec`` is handled the same way:

    - its current value is saved under the same name with a leading
      underscore (``_max_episode_steps`` / ``_max_episode_time``);
    - the public attribute is replaced by ``1e100``, which is effectively
      no limit.

    Attributes that ``spec`` does not have are left untouched. Calling the
    function repeatedly on the same spec is safe: later calls do not
    overwrite the saved original value with ``1e100``.

    Args:
        spec: Environment spec to patch in place.
    """
    no_limit = 1e100
    for name in ("max_episode_steps", "max_episode_time"):
        if not hasattr(spec, name):
            continue
        backup_name = f"_{name}"
        current = getattr(spec, name)
        # On a repeated call the public attribute already holds ``no_limit``;
        # copying it into the backup would erase the genuine original limit.
        if current != no_limit or not hasattr(spec, backup_name):
            setattr(spec, backup_name, current)
        setattr(spec, name, no_limit)
def remove_time_limit(gym_env: gym.Env) -> gym.Env:
    """Remove the maximum time limit of the provided environment.

    The function does two things:

    - If ``gym_env`` exposes a non-``None`` ``spec``, that spec is patched
      with ``remove_time_limit_from_spec``.
    - Every ``TimeLimit`` wrapper in the wrapper chain is removed, however
      deeply it is nested.

    Args:
        gym_env: Environment, possibly wrapped, whose time limit will be removed.

    Returns:
        The environment with all ``TimeLimit`` wrappers removed.

        - If the outermost wrapper(s) were ``TimeLimit``, this is the first
          inner environment that is not a ``TimeLimit``.
        - Otherwise it is ``gym_env`` itself, with any nested ``TimeLimit``
          wrappers unlinked in place.
    """
    spec = getattr(gym_env, "spec", None)
    if spec is not None:
        remove_time_limit_from_spec(spec)
    # Peel off TimeLimit wrappers applied on top of everything else.
    while isinstance(gym_env, TimeLimit):
        gym_env = gym_env.env
    # Walk down the wrapper chain, splicing out nested TimeLimit wrappers.
    # Each iteration either shortens the chain or descends one level, so
    # the loop terminates once a non-wrapper environment is reached.
    parent = gym_env
    while isinstance(parent, gym.Wrapper):
        child = getattr(parent, "env", None)
        if isinstance(child, TimeLimit):
            parent.env = child.env
        else:
            parent = child
    return gym_env
def process_frame_pil(
    frame: numpy.ndarray,
    width: int | None = None,
    height: int | None = None,
    mode: str = "RGB",
) -> numpy.ndarray:
    """Resize an RGB frame to a specified shape and mode.

    Use PIL to resize an RGB frame to a specified height and width \
    or changing it to a different mode.

    Args:
        frame: Target numpy array representing the image that will be resized.
        width: Width of the resized image. Falls back to ``frame.shape[1]``
            when not given (or falsy).
        height: Height of the resized image. Falls back to ``frame.shape[0]``
            when not given (or falsy).
        mode: Passed to ``Image.convert``. Two values get special handling:

            - ``"GRAY"`` is accepted as an alias of ``"L"``.
            - ``"BGR"`` produces the RGB result with its channel axis
              reversed, matching what ``process_frame_opencv`` returns for
              ``mode="BGR"``.

    Returns:
        The resized frame that matches the provided width and height.
    """
    mode = "L" if mode == "GRAY" else mode
    # Pillow has no plain "BGR" mode, so convert to RGB and flip the channels
    # afterwards; this keeps both process_frame backends interchangeable.
    pil_mode = "RGB" if mode == "BGR" else mode
    height = height or frame.shape[0]
    width = width or frame.shape[1]
    image = Image.fromarray(frame).convert(pil_mode).resize(size=(width, height))
    result = numpy.array(image)
    if mode == "BGR":
        result = numpy.ascontiguousarray(result[..., ::-1])
    return result
def process_frame_opencv(
    frame: numpy.ndarray,
    width: int | None = None,
    height: int | None = None,
    mode: str = "RGB",
) -> numpy.ndarray:
    """Resize an RGB frame with OpenCV and optionally change its color mode.

    Args:
        frame: Image array to resize, assumed to be in RGB channel order.
        width: Output width; ``frame.shape[1]`` is used when it is falsy.
        height: Output height; ``frame.shape[0]`` is used when it is falsy.
        mode: Target color mode, handled with ``cv2.cvtColor``:

            - ``"GRAY"`` or ``"L"`` converts to a single grayscale channel.
            - ``"BGR"`` swaps to BGR channel order.
            - Any other value leaves the resized RGB frame unchanged.

    Returns:
        The resized (and possibly color-converted) frame.
    """
    target_size = (width or frame.shape[1], height or frame.shape[0])
    resized = cv2.resize(frame, target_size)
    conversions = {
        "GRAY": cv2.COLOR_RGB2GRAY,
        "L": cv2.COLOR_RGB2GRAY,
        "BGR": cv2.COLOR_RGB2BGR,
    }
    code = conversions.get(mode)
    if code is None:
        return resized
    return cv2.cvtColor(resized, code)
def process_frame(
    frame: numpy.ndarray,
    width: int | None = None,
    height: int | None = None,
    mode: str = "RGB",
) -> numpy.ndarray:
    """Resize an RGB frame and change its mode with the available backend.

    Dispatches to ``process_frame_pil`` when Pillow could be imported
    (``USE_PIL`` is true) and to ``process_frame_opencv`` otherwise.

    Args:
        frame: Image array to resize, assumed to be in RGB channel order.
        width: Output width; the backend falls back to the frame's width.
        height: Output height; the backend falls back to the frame's height.
        mode: Target mode, interpreted by ``Image.convert`` or ``cv2.cvtColor``
            depending on the selected backend.

    Returns:
        The resized frame that matches the provided width and height.
    """
    if USE_PIL:
        return process_frame_pil(frame, width, height, mode)
    return process_frame_opencv(frame, width, height, mode)  # pragma: no cover
class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Convert the image observation from RGB to gray scale.

    The wrapped environment must have a ``Box`` observation space of shape
    ``(..., ..., 3)``. The new observation space is a ``uint8`` ``Box`` with
    bounds ``[0, 255]`` and shape ``(H, W)``, or ``(H, W, 1)`` when
    ``keep_dim=True``.

    Example:
        >>> import gymnasium as gym
        >>> from plangym.utils import GrayScaleObservation
        >>> env = gym.make("CarRacing-v2")
        >>> env.observation_space
        Box(0, 255, (96, 96, 3), uint8)
        >>> env = GrayScaleObservation(gym.make("CarRacing-v2"))
        >>> env.observation_space
        Box(0, 255, (96, 96), uint8)
        >>> env = GrayScaleObservation(gym.make("CarRacing-v2"), keep_dim=True)
        >>> env.observation_space
        Box(0, 255, (96, 96, 1), uint8)
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Convert the image observation from RGB to gray scale.

        Args:
            env (Env): The environment to apply the wrapper
            keep_dim (bool): If `True`, a singleton dimension will be added, i.e. \
                observations are of the shape AxBx1. Otherwise, they are of shape AxB.

        Raises:
            AssertionError: If the wrapped observation space is not a ``Box``
                of rank 3 whose last dimension is 3.
        """
        gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
        gym.ObservationWrapper.__init__(self, env)

        self.keep_dim = keep_dim

        # The class-name check (rather than isinstance) accepts Box spaces
        # from both gym and gymnasium.
        assert (
            "Box" in self.observation_space.__class__.__name__  # works for both gym and gymnasium
            and len(self.observation_space.shape) == 3  # noqa: PLR2004
            and self.observation_space.shape[-1] == 3  # noqa: PLR2004
        ), f"Expected input to be of shape (..., 3), got {self.observation_space.shape}"

        obs_shape = self.observation_space.shape[:2]
        if self.keep_dim:
            self.observation_space = Box(
                low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=numpy.uint8
            )
        else:
            self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=numpy.uint8)

    def observation(self, observation):
        """Convert the colour observation to greyscale.

        Args:
            observation: Color observation, in RGB channel order.

        Returns:
            Grayscale observation of shape ``(H, W)``, or ``(H, W, 1)`` when
            ``keep_dim`` is set.
        """
        # cv2 is already imported at module level; no local import is needed.
        observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            observation = numpy.expand_dims(observation, -1)
        return observation