Source code for src.plangym.videogames.nes

"""Environment for playing Mario bros using gym-super-mario-bros."""

from typing import Any, TypeVar

import gymnasium as gym
import numpy

from plangym.videogames.env import VideogameEnv

# Discrete action sets for JoypadSpace. Each action is a list of NES controller
# buttons pressed together ("NOOP" presses nothing). Every larger set begins with
# all actions of the smaller one, so a shared index selects the same buttons.

# actions for the simple run right environment
RIGHT_ONLY = [
    ["NOOP"],
    ["right"],
    ["right", "A"],
    ["right", "B"],
    ["right", "A", "B"],
]

# actions for very simple movement: RIGHT_ONLY plus "A" alone and "left".
# Inner lists are copied so the sets never share mutable entries.
SIMPLE_MOVEMENT = [list(buttons) for buttons in RIGHT_ONLY] + [["A"], ["left"]]

# actions for more complex movement: SIMPLE_MOVEMENT plus left combos, down and up.
COMPLEX_MOVEMENT = [list(buttons) for buttons in SIMPLE_MOVEMENT] + [
    ["left", "A"],
    ["left", "B"],
    ["left", "A", "B"],
    ["down"],
    ["up"],
]

# Generic type variables used in type annotations.
ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")
RenderFrame = TypeVar("RenderFrame")


class NESWrapper:
    """Proxy that adapts a legacy-API NES environment to gymnasium-style step/reset.

    Attribute reads, writes and deletions are forwarded to the wrapped object; only
    ``_wrapped`` itself is stored on the proxy. :meth:`step` always returns a 5-tuple
    and :meth:`reset` always returns ``(observation, info)``.
    """

    def __init__(self, wrapped):
        """Initialize the NESWrapper.

        Args:
            wrapped: Object whose attributes and ``step`` are proxied.
        """
        self._wrapped = wrapped

    def __getattr__(self, name):
        """Get an attribute from the wrapped object."""
        # __getattr__ only runs when normal lookup fails. If ``_wrapped`` itself is
        # missing (e.g. an instance created through ``__new__`` by ``copy``), forwarding
        # would re-enter __getattr__ forever; raise AttributeError so that hasattr/getattr
        # behave normally instead of hitting RecursionError.
        if name == "_wrapped":
            raise AttributeError(name)
        return getattr(self._wrapped, name)

    def __setattr__(self, name, value):
        """Set an attribute on the wrapped object (``_wrapped`` stays on the proxy)."""
        if name == "_wrapped":
            super().__setattr__(name, value)
        else:
            setattr(self._wrapped, name, value)  # pragma: no cover

    def __delattr__(self, name):
        """Delete an attribute from the wrapped object."""
        delattr(self._wrapped, name)  # pragma: no cover

    def step(
        self, action: ActType
    ) -> tuple[gym.core.WrapperObsType, gym.core.SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env and return ``(obs, reward, terminated, truncated, info)``.

        Both the legacy ``(obs, reward, done, info)`` result and the newer
        ``(obs, reward, terminated, truncated, info)`` result are accepted; the legacy
        form is reported with ``truncated=False``. The observation is passed through
        :meth:`observation`.
        """
        result = self._wrapped.step(action)
        if len(result) == 5:  # noqa: PLR2004
            observation, reward, terminated, truncated, info = result
        else:
            observation, reward, terminated, info = result
            truncated = False
        return self.observation(observation), reward, terminated, truncated, info

    def reset(
        self,
        *,
        seed: int | None = None,  # noqa: ARG002
        options: dict[str, Any] | None = None,  # noqa: ARG002
    ) -> tuple[gym.core.WrapperObsType, dict[str, Any]]:
        """Reset the environment and return ``(observation, info)``.

        ``seed`` and ``options`` are accepted for API compatibility but ignored. The
        reset is issued on ``self.env`` (resolved through the wrapped object). A reset
        that returns only the observation gets an empty info dict; an
        ``(observation, info)`` tuple is unpacked.
        """
        result = self.env.reset()
        if isinstance(result, tuple) and len(result) == 2:  # noqa: PLR2004
            obs, info = result
        else:
            obs, info = result, {}
        return self.observation(obs), info

    def observation(self, observation: ObsType) -> gym.core.WrapperObsType:
        """Return a modified observation (identity by default).

        Args:
            observation: The :attr:`env` observation

        Returns:
            The modified observation
        """
        return observation
class JoypadSpace(gym.Wrapper):
    """Wrapper exposing a discrete action space over NES controller button bitmasks.

    Every discrete action index stands for a list of button names. The bits of those
    buttons are OR-ed into a single byte, which is what the wrapped env's ``step``
    receives.
    """

    # Controller button name -> its bit in the action byte ("NOOP" sets no bit).
    _button_map = {
        "right": 0b10000000,
        "left": 0b01000000,
        "down": 0b00100000,
        "up": 0b00010000,
        "start": 0b00001000,
        "select": 0b00000100,
        "B": 0b00000010,
        "A": 0b00000001,
        "NOOP": 0b00000000,
    }

    @classmethod
    def buttons(cls) -> list:
        """Return the button names that may appear in an action's button list."""
        return [*cls._button_map]

    def __init__(self, env: gym.Env, actions: list):
        """Build the mapping from discrete action indices to button bytes.

        Args:
            env: Environment to wrap; its ``step`` is called with the combined byte.
            actions: Ordered list of button lists. The position of each button list
                is the discrete action value that selects it.

        Raises:
            KeyError: If a button name is not a key of ``_button_map``.
        """
        super().__init__(env)
        self.action_space = gym.spaces.Discrete(len(actions))

        def combine(button_list):
            # OR the bits of every button pressed in this action.
            byte = 0
            for button in button_list:
                byte |= self._button_map[button]
            return byte

        self._action_map = {
            index: combine(button_list) for index, button_list in enumerate(actions)
        }
        self._action_meanings = {
            index: " ".join(button_list) for index, button_list in enumerate(actions)
        }

    def step(self, action):
        """Translate a discrete action to its button byte and step the wrapped env.

        Args:
            action (int): the discrete action to perform

        Returns:
            Whatever the wrapped environment's ``step`` returns for that byte.
        """
        button_byte = self._action_map[action]
        return self.env.step(button_byte)

    def get_keys_to_action(self):
        """Return a mapping from keyboard key combos to discrete actions.

        The unwrapped env's key->byte mapping is inverted, then each discrete action
        is looked up by its byte.
        """
        keys_by_byte = {
            byte: keys for keys, byte in self.env.unwrapped.get_keys_to_action().items()
        }
        return {keys_by_byte[byte]: action for action, byte in self._action_map.items()}

    def get_action_meanings(self):
        """Return the space-joined button names of each action, ordered by action index."""
        return [self._action_meanings[index] for index in sorted(self._action_meanings)]
class NesEnv(VideogameEnv):
    """Environment for working with the NES-py emulator.

    Screen, RAM and emulator state are accessed on the unwrapped NES environment
    (:attr:`nes_env`) rather than via attribute forwarding through the wrapper chain.
    """

    @property
    def nes_env(self) -> "NESEnv":  # noqa: F821
        """Access the underlying, fully unwrapped NESEnv."""
        return self.gym_env.unwrapped

    def get_image(self) -> numpy.ndarray:
        """Return a numpy array containing the rendered view of the environment.

        Square matrices are interpreted as a greyscale image. Three-dimensional arrays
        are interpreted as RGB images with channels (Height, Width, RGB).
        """
        # Read ``screen`` from the unwrapped env: forwarding unknown attributes through
        # gymnasium wrappers is deprecated.
        return self.nes_env.screen.copy()

    def get_ram(self) -> numpy.ndarray:
        """Return a copy of the emulator RAM."""
        return self.nes_env.ram.copy()

    def get_state(self, state: numpy.ndarray | None = None) -> numpy.ndarray:
        """Recover the internal state of the simulation.

        A state must completely describe the Environment at a given moment.

        Args:
            state: Optional buffer forwarded to the emulator's ``get_state``.
        """
        return self.nes_env.get_state(state)

    def set_state(self, state: numpy.ndarray) -> None:
        """Set the internal state of the simulation.

        Args:
            state: Target state to be set in the environment.

        Returns:
            None
        """
        self.nes_env.set_state(state)

    def close(self) -> None:
        """Close the underlying :class:`gym.Env`.

        Does nothing if the emulator handle (``nes_env._env``) is already ``None``;
        a ``ValueError`` raised by the parent ``close`` is ignored.
        """
        if self.nes_env._env is None:
            return
        try:
            super().close()
        except ValueError:  # pragma: no cover
            pass

    def __del__(self):
        """Tear down the environment on a best-effort basis."""
        # AttributeError covers a partially initialized instance (e.g. the wrapped env
        # was never created, so ``gym_env``/``unwrapped`` cannot be resolved); errors
        # in __del__ would otherwise only be printed as "Exception ignored" noise.
        try:
            self.close()
        except (ValueError, AttributeError):  # pragma: no cover
            pass

    def render(self, mode="rgb_array"):  # noqa: ARG002
        """Render the environment as a copy of the emulator screen (``mode`` is ignored)."""
        return self.nes_env.screen.copy()
class MarioEnv(NesEnv):
    """Interface for using gym-super-mario-bros in plangym."""

    AVAILABLE_OBS_TYPES = {"coords", "rgb", "grayscale", "ram"}
    # movement_type -> discrete action set handed to JoypadSpace.
    MOVEMENTS = {
        "complex": COMPLEX_MOVEMENT,
        "simple": SIMPLE_MOVEMENT,
        "right": RIGHT_ONLY,
    }

    def __init__(
        self,
        name: str,
        movement_type: str = "simple",
        original_reward: bool = False,
        **kwargs,
    ):
        """Initialize a MarioEnv.

        Args:
            name: Name of the environment.
            movement_type: One of {complex|simple|right}; selects the discrete action
                set from :attr:`MOVEMENTS`.
            original_reward: If False return a custom reward based on mario position
                and level.
            **kwargs: passed to super().__init__.

        Raises:
            ValueError: If ``movement_type`` is not a key of :attr:`MOVEMENTS`.
        """
        if movement_type not in self.MOVEMENTS:
            msg = (
                f"movement_type must be one of {sorted(self.MOVEMENTS)}, "
                f"got {movement_type!r}"
            )
            raise ValueError(msg)
        self._movement_type = movement_type
        self._original_reward = original_reward
        super().__init__(name=name, **kwargs)

    def get_state(self, state: numpy.ndarray | None = None) -> numpy.ndarray:
        """Recover the internal state of the simulation.

        A state must completely describe the Environment at a given moment.

        Args:
            state: Optional preallocated buffer; a new 250288-byte buffer is created
                when omitted (size presumably matches the emulator snapshot — TODO
                confirm).
        """
        if state is None:
            state = numpy.empty(250288, dtype=numpy.byte)
        state[-2:] = 0  # Some states use the last two bytes. Set to zero by default.
        return super().get_state(state)

    def init_gym_env(self) -> gym.Env:
        """Initialize the :class:`NESEnv` instance that the current class is wrapping.

        The discrete action set is chosen from :attr:`MOVEMENTS` using the
        ``movement_type`` given at construction.
        """
        from gym_super_mario_bros import make  # noqa: PLC0415

        env = make(self.name)
        actions = self.MOVEMENTS[self._movement_type]
        gym_env = NESWrapper(JoypadSpace(env.unwrapped, actions))
        gym_env.reset()
        return gym_env

    def _update_info(self, info: dict[str, Any]) -> dict[str, Any]:
        """Copy internal emulator fields into ``info`` (mutated in place and returned)."""
        nes = self.nes_env
        info["player_state"] = nes._player_state
        info["area"] = nes._area
        info["left_x_position"] = nes._left_x_position
        info["is_stage_over"] = nes._is_stage_over
        info["is_dying"] = nes._is_dying
        info["is_dead"] = nes._is_dead
        info["y_pixel"] = nes._y_pixel
        info["y_viewport"] = nes._y_viewport
        info["x_position_last"] = nes._x_position_last
        # Player states 0x02 and 0x03 are the ones treated as being in a pipe.
        info["in_pipe"] = info["player_state"] in (0x02, 0x03)
        return info

    def _get_info(self):
        """Return an info dict of zeroed defaults enriched by :meth:`_update_info`."""
        info = {
            "x_pos": 0,
            "y_pos": 0,
            "world": 0,
            "stage": 0,
            "life": 0,
            "coins": 0,
            "flag_get": False,
            "in_pipe": False,
        }
        return self._update_info(info)

    def get_coords_obs(
        self,
        obs: numpy.ndarray,
        info: dict[str, Any] | None = None,
        **kwargs,  # noqa: ARG002
    ) -> numpy.ndarray:
        """Return info values as the observation when ``obs_type == "coords"``.

        The observation is ``[x_pos, y_pos, world, stage, life, flag_get, coins]``,
        with missing keys reported as 0. For any other ``obs_type`` ``obs`` is
        returned unchanged.
        """
        if self.obs_type == "coords":
            info = info or self._get_info()
            obs = numpy.array(
                [
                    info.get("x_pos", 0),
                    info.get("y_pos", 0),
                    info.get("world", 0),
                    info.get("stage", 0),
                    info.get("life", 0),
                    int(info.get("flag_get", 0)),
                    info.get("coins", 0),
                ],
            )
        return obs

    def process_reward(self, reward, info, **kwargs) -> float:  # noqa: ARG002
        """Return a custom reward based on the x, y coordinates and level mario is in.

        Unless ``original_reward`` was requested, the reward is replaced by
        ``world * 25000 + stage * 5000 + x_pos``, plus 10 when in a pipe and 100 when
        the flag was reached.
        """
        if self._original_reward:
            return reward
        world = int(info.get("world", 0))
        stage = int(info.get("stage", 0))
        x_pos = int(info.get("x_pos", 0))
        return (
            (world * 25000)
            + (stage * 5000)
            + x_pos
            + 10 * int(bool(info.get("in_pipe", 0)))
            + 100 * int(bool(info.get("flag_get", 0)))
        )

    def process_terminal(self, terminal, info, **kwargs) -> bool:  # noqa: ARG002
        """Return True if terminal or mario is dying or dead."""
        return bool(terminal or info.get("is_dying", False) or info.get("is_dead", False))

    def process_info(self, info, **kwargs) -> dict[str, Any]:  # noqa: ARG002
        """Add additional data to the info dictionary."""
        return self._update_info(info)