Source code for src.plangym.api_tests

import copy
from collections.abc import Callable, Iterator
from itertools import product
import os
import warnings

import gymnasium as gym
import numpy
import pytest
from pyvirtualdisplay import Display

import plangym
from plangym.core import PlanEnv, PlangymEnv
from plangym.vectorization.env import VectorizedEnv


def generate_test_cases(
    names,
    env_class,
    n_workers_values=None,
    render_modes=None,
    obs_types=None,
    custom_tests=None,
) -> "Iterator[Callable[[], PlangymEnv]]":
    """Yield zero-argument factories that build environments for parametrized tests.

    One factory is produced for every combination of ``n_workers_values`` x ``obs_types`` x
    ``render_modes``, in :func:`itertools.product` order. Environment names are assigned
    round-robin from ``names``, and tuple names are joined with ``"-"``. After the generated
    factories, every item of ``custom_tests`` is yielded unchanged.

    Args:
        names: A single environment name, or a sequence of names (strings or tuples of strings).
        env_class: Environment class whose ``AVAILABLE_RENDER_MODES`` and ``AVAILABLE_OBS_TYPES``
            are used when ``render_modes`` / ``obs_types`` are not given.
        n_workers_values: Values passed as ``n_workers``. Defaults to ``[None]``.
        render_modes: Render modes to test. Defaults to ``env_class.AVAILABLE_RENDER_MODES``,
            or ``[None]`` when the ``SKIP_RENDER`` environment variable is set to a non-empty value.
        obs_types: Observation types to test. Uses the same defaulting rule as ``render_modes``.
        custom_tests: Extra items yielded after the generated factories.

    Yields:
        Callables taking no arguments that return ``plangym.make(...)`` for one combination,
        followed by the items of ``custom_tests``.
    """
    custom_tests = custom_tests or []
    n_workers_vals = [None] if n_workers_values is None else n_workers_values
    names = [names] if isinstance(names, str) else names
    # Any non-empty SKIP_RENDER value (even "0") disables render-mode/obs-type expansion.
    skip_render = bool(os.getenv("SKIP_RENDER", False))
    if render_modes is None:
        render_modes = [None] if skip_render else env_class.AVAILABLE_RENDER_MODES
    if obs_types is None:
        obs_types = [None] if skip_render else env_class.AVAILABLE_OBS_TYPES

    combinations = product(n_workers_vals, obs_types, render_modes)
    for i, (n_workers, obs_type, render_mode) in enumerate(combinations):
        name = names[i % len(names)]
        if isinstance(name, tuple):
            name = "-".join(name)

        # Bind the current loop values as defaults. A plain closure would look the names up
        # when called, so factories called after the generator has been consumed (e.g.
        # materialized into a parameter list) would all build the last combination.
        def _make_env(
            name=name,
            n_workers=n_workers,
            obs_type=obs_type,
            render_mode=render_mode,
        ):
            return plangym.make(
                name,
                n_workers=n_workers,
                obs_type=obs_type,
                render_mode=render_mode,
            )

        yield _make_env
    yield from custom_tests
@pytest.fixture(scope="class")
def batch_size() -> int:
    """Provide the batch size used by the batched-step tests (one value per test class)."""
    size = 10
    return size
@pytest.fixture(scope="module")
def display():
    """Run a hidden 400x400 virtual display for the duration of the test module.

    Sets ``PYVIRTUALDISPLAY_DISPLAYFD`` to ``"0"`` before starting the display, yields the
    running ``Display`` and stops it during fixture teardown.
    """
    os.environ["PYVIRTUALDISPLAY_DISPLAYFD"] = "0"
    virtual_display = Display(visible=False, size=(400, 400))
    virtual_display.start()
    yield virtual_display
    virtual_display.stop()
def step_tuple_test(env, obs, reward, terminal, info, dt=None):
    """Assert that one ``env.step`` result is consistent with the environment's API.

    Args:
        env: Environment that produced the step. Its ``OBS_IS_ARRAY``, ``obs_shape``,
            ``frameskip`` and ``return_image`` attributes are read.
        obs: Observation returned by the step. Must be an ndarray iff ``env.OBS_IS_ARRAY``,
            with shape ``env.obs_shape``.
        reward: Reward returned by the step. Must be convertible to a non-NaN float.
        terminal: Terminal flag. Must be a Python ``bool``.
        info: Info dictionary. Must contain ``"n_step"`` and ``"dt"``, plus an ndarray under
            ``"rgb"`` when ``env.return_image`` is set.
        dt: The ``dt`` passed to the step. When ``None`` the ``dt``-dependent checks are
            skipped instead of failing on ``None * frameskip``.

    Raises:
        AssertionError: If any check fails.
    """
    obs_is_array = isinstance(obs, numpy.ndarray)
    assert obs_is_array if env.OBS_IS_ARRAY else not obs_is_array
    assert obs.shape == env.obs_shape, (obs.shape, env.obs_shape)
    # float() rejects non-numeric rewards; NaN is rejected explicitly.
    assert not numpy.isnan(float(reward)), reward
    assert isinstance(terminal, bool)
    assert isinstance(info, dict)
    assert "n_step" in info
    assert "dt" in info
    if dt is not None:
        # The reported step count may not exceed dt * frameskip.
        assert info["n_step"] <= int(dt * env.frameskip), (dt, env.frameskip, info["n_step"])
        assert info["dt"] == dt
    if env.return_image:
        assert "rgb" in info
        assert isinstance(info["rgb"], numpy.ndarray)
def step_batch_tuple_test(env, batch_size, observs, rewards, terminals, infos, dt):
    """Assert that a batched step result holds ``batch_size`` consistent step tuples.

    ``dt`` is either one value shared by the whole batch or a list / ndarray holding one
    value per batch item. Each item is validated with ``step_tuple_test``.
    """
    for batch_field in (rewards, terminals, observs, infos):
        assert len(batch_field) == batch_size
    if isinstance(dt, (list, numpy.ndarray)):
        per_item_dt = dt
    else:
        per_item_dt = [dt] * batch_size
    batch_items = zip(list(observs), rewards, terminals, infos, per_item_dt)
    for obs, reward, terminal, info, item_dt in batch_items:
        step_tuple_test(
            env=env,
            obs=obs,
            reward=reward,
            terminal=terminal,
            info=info,
            dt=item_dt,
        )
class TestPlanEnv:
    """Generic API tests that every ``PlanEnv`` (including vectorized wrappers) should pass.

    Tests receive the environment under test through an ``env`` fixture that is not defined in
    this module (presumably supplied by the test module reusing this class). Batched tests
    also use the ``batch_size`` fixture.
    """

    # Boolean flags that every environment class must define.
    CLASS_ATTRIBUTES = ("OBS_IS_ARRAY", "STATE_IS_ARRAY", "SINGLETON")
    # Attributes/properties that every environment instance must expose.
    PROPERTIES = (
        "unwrapped",
        "obs_shape",
        "action_shape",
        "name",
        "frameskip",
        "autoreset",
        "delay_setup",
        "return_image",
        "img_shape",
    )

    def test_init(self, env):
        """Smoke test: creating the ``env`` fixture must not raise."""

    def test_repr(self, env):
        """``str(env)`` and ``repr(env)`` must be identical."""
        assert str(env) == repr(env)

    # Test attributes and properties
    # ---------------------------------------------------------------------------------------------

    def test_class_attributes(self, env):
        """Every name in ``CLASS_ATTRIBUTES`` must be defined on the class as a ``bool``."""
        env_class = env.__class__
        for name in self.CLASS_ATTRIBUTES:
            assert hasattr(env_class, name), f"Env {env.name} does not have attribute {name}"
            value = getattr(env_class, name)
            # The type check must be asserted: a bare isinstance(...) expression checks nothing.
            assert isinstance(value, bool), (name, value)

    def test_has_properties(self, env):
        """Every name in ``PROPERTIES`` must be accessible on the instance."""
        for name in self.PROPERTIES:
            assert hasattr(env, name), f"Env {env.name} does not have property {name}"

    def test_name(self, env):
        """``env.name`` must be a string."""
        assert isinstance(env.name, str)

    def test_obs_shape(self, env):
        """``obs_shape`` must be a tuple of ints matching the observations of reset and step."""
        assert hasattr(env, "obs_shape")
        assert isinstance(env.obs_shape, tuple)
        for val in env.obs_shape:
            assert isinstance(val, int)
        obs, _info = env.reset(return_state=False)
        assert obs.shape == env.obs_shape, (obs.shape, env.obs_shape)
        obs, *_ = env.step(env.sample_action())
        assert obs.shape == env.obs_shape, (obs.shape, env.obs_shape)

    def test_img_shape(self, env):
        """``img_shape`` must be ``None`` or a tuple of ints."""
        assert hasattr(env, "img_shape")
        if env.img_shape is None:
            return
        assert isinstance(env.img_shape, tuple)
        for val in env.img_shape:
            assert isinstance(val, int)

    def test_action_shape(self, env):
        """``action_shape`` must be a tuple of ints."""
        assert hasattr(env, "action_shape")
        assert isinstance(env.action_shape, tuple)
        for val in env.action_shape:
            assert isinstance(val, int)

    def test_unwrapped(self, env):
        """``env.unwrapped`` must be a ``PlanEnv`` instance."""
        assert isinstance(env.unwrapped, PlanEnv)

    @pytest.mark.skipif(os.getenv("SKIP_RENDER", False), reason="No display in CI.")
    @pytest.mark.parametrize("return_image", [True, False])
    def test_return_image(self, env, return_image):
        """With ``return_image`` enabled, ``step`` must put an ``"rgb"`` entry in ``info``."""
        assert isinstance(env.return_image, bool)
        # For vectorized envs the private flag is set on the wrapped ``plan_env``.
        if isinstance(env, VectorizedEnv):
            env.plan_env._return_image = return_image
        else:
            env._return_image = return_image
        _ = env.reset()
        *_, info = env.step(env.sample_action())
        if env.return_image:
            assert "rgb" in info

    # Test public API functions
    # ---------------------------------------------------------------------------------------------

    def test_sample_action(self, env):
        """Sampled actions must have ``action_shape`` when that shape is non-empty."""
        action = env.sample_action()
        if env.action_shape:
            assert action.shape == env.action_shape

    def test_get_state(self, env):
        """``get_state`` must return the declared state kind, equal to the state from ``reset``."""
        state_reset, _obs, _info = env.reset()
        state = env.get_state()
        state_is_array = isinstance(state, numpy.ndarray)
        assert state_is_array if env.STATE_IS_ARRAY else not state_is_array
        if state_is_array and not env.SINGLETON:
            assert (state == state_reset).all(), f"original: {state} env: {env.get_state()}"

    def test_set_state(self, env):
        """Restoring a saved state after a step must succeed and keep the state's shape."""
        env.reset()
        state = env.get_state()
        env.step(env.sample_action())
        env.set_state(state)
        if env.STATE_IS_ARRAY:
            env_state = env.get_state()
            assert state.shape == env_state.shape
            # NOTE(review): a numpy dtype is never the builtin ``object`` type, so this ``is``
            # comparison is always False and the equality check below never runs. Confirm
            # whether ``state.dtype == object`` or ``state.dtype != object`` was intended.
            if state.dtype is object and not env.SINGLETON:
                assert (state == env_state).all(), (state, env.get_state())

    def test_reset(self, env):
        """``reset`` must return the declared state/observation kinds and a valid info dict."""
        _ = env.reset(return_state=False)
        state, obs, info = env.reset(return_state=True)
        if env.return_image:
            assert "rgb" in info
            assert isinstance(info["rgb"], numpy.ndarray)
            assert info["rgb"].shape == env.img_shape
        state_is_array = isinstance(state, numpy.ndarray)
        obs_is_array = isinstance(obs, numpy.ndarray)
        assert isinstance(info, dict), info
        assert state_is_array if env.STATE_IS_ARRAY else not state_is_array
        assert obs_is_array if env.OBS_IS_ARRAY else not obs_is_array

    @pytest.mark.parametrize("state", [None, True])
    @pytest.mark.parametrize("return_state", [None, True, False])
    def test_step(self, env, state, return_state, dt=1):
        """Check ``step`` with and without an explicit ``state`` and every ``return_state``.

        The new state must be returned when ``return_state`` is True, or when it is ``None``
        and a state was passed. Otherwise only the step tuple is returned.
        """
        _state, *_ = env.reset(return_state=True)
        if state is not None:
            state = _state
        action = env.sample_action()
        data = env.step(action, dt=dt, state=state, return_state=return_state)
        *new_state, obs, reward, terminal, _truncated, info = data
        assert isinstance(data, tuple)
        # Test return state works correctly
        should_return_state = state is not None if return_state is None else return_state
        if should_return_state:
            assert len(new_state) == 1
            new_state = new_state[0]
            state_is_array = isinstance(new_state, numpy.ndarray)
            assert state_is_array if env.STATE_IS_ARRAY else not state_is_array
            if state_is_array:
                assert _state.shape == new_state.shape
            if not env.SINGLETON and env.STATE_IS_ARRAY:
                # The returned state must match what the env reports right after the step.
                curr_state = env.get_state()
                assert new_state.shape == curr_state.shape
                assert (new_state == curr_state).all(), (
                    f"original: {new_state[new_state != curr_state]} "
                    f"env: {curr_state[new_state != curr_state]}"
                )
        else:
            assert len(new_state) == 0
        step_tuple_test(env, obs, reward, terminal, info, dt=dt)

    @pytest.mark.parametrize("states", [None, True, "None_list"])
    @pytest.mark.parametrize("return_state", [None, True, False])
    def test_step_batch(self, env, states, return_state, batch_size):
        """Check ``step_batch`` with no states, a list of states, or a list of ``None``.

        By default, states are returned only when a list whose first item is not ``None`` was
        passed. ``return_state=True`` always requests them.
        """
        dt = 1
        state, *_ = env.reset()
        if states == "None_list":
            states = [None] * batch_size
        elif states:
            states = [copy.deepcopy(state) for _ in range(batch_size)]
        actions = [env.sample_action() for _ in range(batch_size)]

        data = env.step_batch(actions, dt=dt, states=states, return_state=return_state)
        *new_states, observs, rewards, terminals, _truncated, infos = data
        assert isinstance(data, tuple)
        # Test return state works correctly
        default_returns_state = (
            return_state is None and isinstance(states, list) and states[0] is not None
        )
        should_return_state = return_state or default_returns_state
        if should_return_state:
            assert len(new_states) == 1
            new_states = new_states[0]
            # TODO: update check when returning batch arrays is available
            assert isinstance(new_states, list)
            state_is_array = isinstance(new_states[0], numpy.ndarray)
            assert state_is_array if env.STATE_IS_ARRAY else not state_is_array
            if env.STATE_IS_ARRAY:
                assert state.shape == new_states[0].shape
        else:
            assert len(new_states) == 0, (len(new_states), should_return_state, return_state)

        step_batch_tuple_test(
            env=env,
            batch_size=batch_size,
            observs=observs,
            rewards=rewards,
            terminals=terminals,
            infos=infos,
            dt=dt,
        )

    def test_step_dt_values(self, env, dt=3, return_state=None):
        """A stateless ``step`` with ``dt=3`` must return only the step tuple with valid info."""
        state = None
        _state, *_ = env.reset(return_state=True)
        action = env.sample_action()
        data = env.step(action, dt=dt, state=state, return_state=return_state)
        *new_state, obs, reward, terminal, _truncated, info = data
        assert isinstance(data, tuple)
        assert len(new_state) == 0
        step_tuple_test(env, obs, reward, terminal, info, dt=dt)

    @pytest.mark.parametrize("dt", [3, "array"])
    def test_step_batch_dt_values(self, env, dt, batch_size, states=None, return_state=None):
        """``step_batch`` must accept a scalar ``dt`` or one random ``dt`` in [1, 3] per item."""
        rng = numpy.random.default_rng()
        dt = dt if dt != "array" else rng.integers(1, 4, batch_size).astype(int)
        _state, *_ = env.reset()
        actions = [env.sample_action() for _ in range(batch_size)]

        data = env.step_batch(actions, dt=dt, states=states, return_state=return_state)
        *new_states, observs, rewards, terminals, _truncated, infos = data
        assert isinstance(data, tuple)
        assert len(new_states) == 0, (len(new_states), return_state)
        step_batch_tuple_test(
            env=env,
            batch_size=batch_size,
            observs=observs,
            rewards=rewards,
            terminals=terminals,
            infos=infos,
            dt=dt,
        )

    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
    @pytest.mark.parametrize("delay_setup", [False, True])
    def test_clone_and_close(self, env, delay_setup):
        """For non-singleton envs, clones (with or without delayed setup) must work and close."""
        if not env.SINGLETON:
            env.delay_setup = delay_setup
            clone = env.clone()
            if clone.delay_setup:
                clone.reset()
            del clone
            clone = env.clone()
            if clone.delay_setup:
                clone.setup()
            clone.close()

    @pytest.mark.skipif(os.getenv("SKIP_RENDER", False), reason="No display in CI.")
    def test_get_image(self, env):
        """``get_image`` must return ``None`` or a 2-D / 3-D numpy array."""
        img = env.get_image()
        if img is not None:
            assert isinstance(img, numpy.ndarray)
            assert img.ndim in (2, 3)
class TestPlangymEnv:
    """API tests for ``PlangymEnv`` environments, which expose an underlying ``gym_env``.

    Tests receive the environment under test through an ``env`` fixture that is not defined
    in this module (presumably supplied by the test module reusing this class).
    """

    # Class-level attributes every ``PlangymEnv`` subclass must define.
    CLASS_ATTRIBUTES = ("AVAILABLE_OBS_TYPES", "DEFAULT_OBS_TYPE")
    # Attributes/properties every instance must expose.
    PROPERTIES = (
        "gym_env",
        "obs_shape",
        "obs_type",
        "observation_space",
        "action_shape",
        "action_space",
        "reward_range",
        "metadata",
        "render_mode",
        "remove_time_limit",
        "name",
        "frameskip",
        "autoreset",
        "delay_setup",
        "return_image",
    )

    def test_class_attributes(self, env):
        """Every name in ``CLASS_ATTRIBUTES`` must be defined on the environment class.

        Only presence is checked here. These attributes are not booleans; their values are
        checked by ``test_obs_type``.
        """
        for name in self.CLASS_ATTRIBUTES:
            assert hasattr(env.__class__, name), f"Env {env.name} does not have attribute {name}"

    def test_has_properties(self, env):
        """Every name in ``PROPERTIES`` must be accessible on the instance."""
        for name in self.PROPERTIES:
            assert hasattr(env, name), f"Env {env.name} does not have property {name}"

    def test_obs_type(self, env):
        """``obs_type`` and ``DEFAULT_OBS_TYPE`` must both be members of ``AVAILABLE_OBS_TYPES``."""
        assert isinstance(env.obs_type, str)
        assert env.obs_type in env.AVAILABLE_OBS_TYPES
        assert env.DEFAULT_OBS_TYPE in env.AVAILABLE_OBS_TYPES, (
            str(env.DEFAULT_OBS_TYPE),
            env.AVAILABLE_OBS_TYPES,
        )

    def test_obvervation_space(self, env):
        """``observation_space.shape`` must equal ``obs_shape`` and the shape of reset observations.

        The misspelled test name is kept so that the test id stays stable.
        """
        assert hasattr(env, "observation_space")
        # NOTE: unlike ``action_space``, ``observation_space`` is not required to be a
        # ``gym.Space`` instance here.
        assert env.observation_space.shape == env.obs_shape
        if env.observation_space.shape:
            obs, *_info = env.reset(return_state=False)
            obs_shape = env.observation_space.shape
            assert obs_shape == obs.shape, (obs_shape, obs.shape)

    def test_action_space(self, env):
        """``action_space`` must be a ``gym.Space`` whose shape matches ``action_shape``."""
        assert hasattr(env, "action_space")
        assert isinstance(env.action_space, gym.Space)
        assert env.action_space.shape == env.action_shape
        if env.action_space.shape:
            assert env.action_space.shape == env.sample_action().shape

    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
    def test_gym_env(self, env):
        """``gym_env`` must expose ``reset``/``step`` and stay accessible after ``close``."""
        assert hasattr(env.gym_env, "reset")
        assert hasattr(env.gym_env, "step")
        if not isinstance(env, VectorizedEnv) and not env.SINGLETON:
            env.close()
            # Accessing the property after close must not raise.
            _ = env.gym_env

    def test_reward_range(self, env):
        """Accessing ``reward_range`` must not raise."""
        _ = env.reward_range

    @pytest.mark.parametrize("delay_setup", [True, False])
    def test_delay_setup(self, env, delay_setup):
        """A clone made with ``delay_setup=True`` must not build its gym env until set up."""
        if env.SINGLETON or isinstance(env, VectorizedEnv):
            return
        new_env = env.clone(delay_setup=delay_setup)
        if delay_setup:
            assert new_env._gym_env is None
        else:
            assert new_env._gym_env is not None
        assert env.gym_env is not None

    def test_has_metadata(self, env):
        """The environment must expose a ``metadata`` attribute."""
        assert hasattr(env, "metadata")

    def test_render_mode(self, env):
        """``render_mode`` must be ``None`` or one of ``AVAILABLE_RENDER_MODES``."""
        assert hasattr(env, "render_mode")
        if env.render_mode is not None:
            assert isinstance(env.render_mode, str)
            assert env.render_mode in env.AVAILABLE_RENDER_MODES

    def test_remove_time_limit(self, env):
        """With ``remove_time_limit`` and no extra wrappers, no ``TimeLimit`` wrapper may remain."""
        assert isinstance(env.remove_time_limit, bool)
        if env.remove_time_limit and not env._wrappers:
            assert "TimeLimit" not in str(env.gym_env), env.gym_env

    def test_seed(self, env):
        """``seed`` must accept being called with no argument and with an integer."""
        env.seed()
        env.seed(1)

    def test_terminal(self, env):
        """For autoreset envs, a very long step (``dt=1000``) must run without error.

        Skipped for envs rendering in ``"human"`` or ``"rgb_array"`` mode.
        """
        if env.autoreset:
            if not env.SINGLETON:
                env.setup()
            env.reset()
            if hasattr(env, "render_mode") and env.render_mode in {"human", "rgb_array"}:
                return
            env.step_with_dt(env.sample_action(), dt=1000)

    @pytest.mark.skipif(os.getenv("SKIP_RENDER", False), reason="No display in CI.")
    def test_render(self, env, display):
        """``render`` must run on the virtual display without raising."""
        # catch_warnings restores the global warning filters afterwards.
        with warnings.catch_warnings():
            env.render()

    def test_wrap_environment(self, env):
        """``apply_wrappers`` must wrap ``gym_env`` for each supported argument form."""
        if isinstance(env, VectorizedEnv):
            return
        # NOTE(review): this imports the legacy ``gym`` package, while the rest of this module
        # uses ``gymnasium``. gymnasium's ``TransformReward`` names its callable ``func``
        # rather than ``f``. Confirm which package is intended.
        from gym.wrappers.transform_reward import TransformReward  # noqa: PLC0415

        # Exercise three argument forms: a kwargs dict, a positional-args list and a bare value.
        wrapper_args = ({"f": lambda x: x}, [lambda x: x], lambda x: x)
        for args in wrapper_args:
            env.apply_wrappers([(TransformReward, args)])
            assert isinstance(env.gym_env, TransformReward)
            # Unwrap so the next form is applied to the original environment.
            env._gym_env = env.gym_env.env
[docs] class TestVideogameEnv: """Test the VideogameEnv class."""
[docs] def test_ram(self, env): """Test the ram property.""" assert hasattr(env, "get_ram") assert isinstance(env.get_ram(), numpy.ndarray)