"""Handle parallelization for ``plangym.Environment`` that allows vectorized steps."""
import atexit
import multiprocessing
import sys
import traceback
import numpy
from plangym.core import PlanEnv
from plangym.vectorization.env import VectorizedEnv
[docs]
class ExternalProcess:
"""Step environment in a separate process for lock free paralellism.
The environment will be created in the external process by calling the
specified callable. This can be an environment class, or a function
creating the environment and potentially wrapping it. The returned
environment should not access global variables.
Args:
constructor: Callable that creates and returns an OpenAI gym environment.
Attributes:
observation_space: The cached observation space of the environment.
action_space: The cached action space of the environment.
..notes:
This is mostly a copy paste from
https://github.com/tensorflow/agents/blob/master/agents/tools/wrappers.py,
that lets us set and read the environment state.
"""
# Message types for communication via the pipe.
_ACCESS = 1
_CALL = 2
_RESULT = 3
_EXCEPTION = 4
_CLOSE = 5
def __init__(self, constructor):
"""Initialize a :class:`ExternalProcess`.
Args:
constructor: Callable that returns the target environment that will be parallelized.
"""
self._conn, conn = multiprocessing.Pipe()
self._process = multiprocessing.Process(target=self._worker, args=(constructor, conn))
atexit.register(self.close)
self._process.start()
self._observ_space = None
self._action_space = None
@property
def observation_space(self):
"""Return the observation space of the internal environment."""
if not self._observ_space:
self._observ_space = self.__getattr__("observation_space") # noqa: PLC2801
return self._observ_space
@property
def action_space(self):
"""Return the action space of the internal environment."""
if not self._action_space:
self._action_space = self.__getattr__("action_space") # noqa: PLC2801
return self._action_space
[docs]
def __getattr__(self, name):
"""Request an attribute from the environment.
Note that this involves communication with the external process, so it can \
be slow.
Args:
name: Attribute to access.
Returns:
Value of the attribute.
"""
self._conn.send((self._ACCESS, name))
return self._receive()
[docs]
def call(self, name, *args, **kwargs):
"""Asynchronously call a method of the external environment.
Args:
name: Name of the method to call.
*args: Positional arguments to forward to the method.
**kwargs: Keyword arguments to forward to the method.
Returns:
Promise object that blocks and provides the return value when called.
"""
payload = name, args, kwargs
self._conn.send((self._CALL, payload))
return self._receive
[docs]
def close(self):
"""Send a close message to the external process and join it."""
try:
self._conn.send((self._CLOSE, None))
self._conn.close()
except OSError:
# The connection was already closed.
pass
self._process.join()
[docs]
def set_state(self, state, blocking=True):
"""Set the state of the internal environment."""
promise = self.call("set_state", state)
return promise() if blocking else promise
[docs]
def step_batch(
self,
actions,
states=None,
dt: numpy.ndarray | int = None,
return_state: bool | None = None,
blocking=True,
):
"""Vectorized version of the ``step`` method.
It allows to step a vector of states and actions. The signature and \
behaviour is the same as ``step``, but taking a list of states, actions \
and dts as input.
Args:
actions: Iterable containing the different actions to be applied.
states: Iterable containing the different states to be set.
dt: int or array containing the frameskips that will be applied.
blocking: If True, execute sequentially.
return_state: Whether to return the state in the returned tuple. \
If None, `step` will return the state if `state` was passed as a parameter.
Returns:
if states is None returns ``(observs, rewards, ends, infos)``
else returns ``(new_states, observs, rewards, ends, infos)``
"""
promise = self.call("step_batch", actions, states, dt, return_state)
return promise() if blocking else promise
[docs]
def step(self, action, state=None, dt: int = 1, blocking=True):
"""Step the environment.
Args:
action: The action to apply to the environment.
state: State to be set on the environment before stepping it.
dt: Number of consecutive times that action will be applied.
blocking: Whether to wait for the result.
Returns:
Transition tuple when blocking, otherwise callable that returns the \
transition tuple.
"""
promise = self.call("step", action, state, dt)
return promise() if blocking else promise
[docs]
def reset(self, blocking=True, return_states: bool = False):
"""Reset the environment.
Args:
blocking: Whether to wait for the result.
return_states: If true return also the initial state of the environment.
Returns:
New observation when blocking, otherwise callable that returns the new \
observation.
"""
promise = self.call("reset", return_state=return_states)
return promise() if blocking else promise
[docs]
def _receive(self):
"""Wait for a message from the worker process and return its payload.
Raises
Exception: An exception was raised inside the worker process.
KeyError: The received message is of an unknown type.
Returns
Payload object of the message.
"""
message, payload = self._conn.recv()
# Re-raise exceptions in the main process.
if message == self._EXCEPTION:
stacktrace = payload # pragma: no cover
raise Exception(stacktrace) # pragma: no cover
if message == self._RESULT:
return payload
raise KeyError(f"Received unexpected message {message}") # pragma: no cover
[docs]
def _worker(self, constructor, conn):
"""Wait for actions and send back environment results.
Args:
constructor: Constructor for the OpenAI Gym environment.
conn: Connection for communication to the main process.
Raises:
KeyError: When receiving a message of unknown type.
"""
try:
env = constructor()
env.reset()
while True:
try:
# Only block for short times to have keyboard exceptions be raised.
if not conn.poll(0.1):
continue
message, payload = conn.recv()
except (EOFError, KeyboardInterrupt): # pragma: no cover
break
if message == self._ACCESS:
name = payload
result = getattr(env, name)
conn.send((self._RESULT, result))
continue
if message == self._CALL:
name, args, kwargs = payload
result = getattr(env, name)(*args, **kwargs)
conn.send((self._RESULT, result))
continue
if message == self._CLOSE:
assert payload is None
break # pragma: no cover
raise KeyError(
f"Received message of unknown type {message}",
) # pragma: no cover
except Exception: # pragma: no cover # pylint: disable=broad-except
stacktrace = "".join(traceback.format_exception(*sys.exc_info()))
conn.send((self._EXCEPTION, stacktrace))
conn.close()
[docs]
class BatchEnv:
"""Combine multiple environments to step them in batch.
It is mostly a copy paste from \
https://github.com/tensorflow/agents/blob/master/agents/tools/wrappers.py \
that also allows to set and get the states.
To step environments in parallel, environments must support a \
``blocking=False`` argument to their step and reset functions that \
makes them return callables instead to receive the result at a later time.
Args:
envs: List of environments.
blocking: Step environments after another rather than in parallel.
Raises:
ValueError: Environments have different observation or action spaces.
"""
def __init__(self, envs, blocking):
"""Initialize a :class:`BatchEnv`.
Args:
envs: List of :class:`ExternalProcess` that contain the target environment.
blocking: If ``True`` perform the steps sequentially. If ``False`` step \
the environments in parallel.
"""
self._envs = envs
self._blocking = blocking
[docs]
def __len__(self) -> int:
"""Return the number of combined environments."""
return len(self._envs)
[docs]
def __getitem__(self, index):
"""Access an underlying environment by index."""
return self._envs[index]
[docs]
def __getattr__(self, name):
"""Forward unimplemented attributes to one of the original environments.
Args:
name: Attribute that was accessed.
Returns:
Value behind the attribute name one of the wrapped environments.
"""
return getattr(self._envs[0], name)
[docs]
def make_transitions(
self,
actions,
states=None,
dt: numpy.ndarray | int = 1,
return_state: bool | None = None,
):
"""Implement the logic for stepping the environment in parallel."""
results = []
no_states = states is None or states[0] is None
if return_state is None:
_return_state = not no_states
else:
_return_state = return_state
chunks = ParallelEnv.batch_step_data(
actions=actions,
states=states,
dt=dt,
batch_size=len(self._envs),
)
for env, states_batch, actions_batch, _dt in zip(self._envs, *chunks):
result = env.step_batch(
actions=actions_batch,
states=states_batch,
dt=_dt,
blocking=self._blocking,
return_state=return_state,
)
results.append(result)
results = [res if self._blocking else res() for res in results]
return ParallelEnv.unpack_transitions(results=results, return_states=_return_state)
[docs]
def sync_states(self, state, blocking: bool = True) -> None:
"""Set the same state to all the environments that are inside an external process.
Args:
state: Target state to set on the environments.
blocking: If ``True`` perform the update sequentially. If ``False`` step \
the environments in parallel.
Returns:
None.
"""
for env in self._envs:
try:
env.set_state(state, blocking=blocking)
except EOFError: # noqa: PERF203
continue
[docs]
def reset(self, indices=None, return_states: bool = True):
"""Reset the environment and return the resulting batch data.
Args:
indices: The batch indices of environments to reset; defaults to all.
return_states: return the corresponding states after reset.
Returns:
Batch of observations. If ``return_states`` is ``True`` return a tuple \
containing ``(batch_of_observations, batch_of_states)``.
"""
if indices is None:
indices = numpy.arange(len(self._envs))
trans = [
self._envs[index].reset(return_states=return_states, blocking=self._blocking)
for index in indices
]
if not self._blocking:
trans = [trans() for trans in trans]
if return_states:
states, obs, infos = zip(*trans)
states, obs, infos = numpy.array(states), numpy.stack(obs), numpy.array(infos)
return states, obs, infos
obs, infos = zip(*trans)
obs, infos = numpy.stack(obs), numpy.array(infos)
return obs, infos
[docs]
def close(self):
"""Send close messages to the external process and join them."""
for env in self._envs:
if hasattr(env, "close"):
env.close()
class ParallelEnv(VectorizedEnv):
    """Allow any environment to be stepped in parallel when step_batch is called.

    It creates a local instance of the target environment to call all other methods.

    Example::

        >>> from plangym.videogames import AtariEnv
        >>> env = ParallelEnv(env_class=AtariEnv,
        ...                   name="MsPacman-v0",
        ...                   clone_seeds=True,
        ...                   autoreset=True,
        ...                   blocking=False)
        >>>
        >>> state, obs, info = env.reset()
        >>>
        >>> states = [state.copy() for _ in range(10)]
        >>> actions = [env.sample_action() for _ in range(10)]
        >>>
        >>> data = env.step_batch(states=states, actions=actions)
        >>> new_states, observs, rewards, ends, truncateds, infos = data

    """

    def __init__(
        self,
        env_class,
        name: str,
        frameskip: int = 1,
        autoreset: bool = True,
        delay_setup: bool = False,
        n_workers: int = 8,
        blocking: bool = False,
        **kwargs,
    ):
        """Initialize a :class:`ParallelEnv`.

        Args:
            env_class: Class of the environment to be wrapped.
            name: Name of the environment.
            frameskip: Number of times ``step`` will be called with the same action.
            autoreset: Forwarded to the parent initializer. The worker
                environments created in ``setup`` are always built with
                ``autoreset=True``.
            delay_setup: If ``True`` do not initialize the ``gym.Environment``
                and wait for ``setup`` to be called later.
            n_workers: Number of workers that will be used to step the env.
            blocking: Step the environments synchronously.
            **kwargs: Additional kwargs for the environment.

        """
        # Both attributes are assigned before delegating to the parent initializer.
        self._blocking = blocking
        self._batch_env = None
        super().__init__(
            env_class=env_class,
            name=name,
            frameskip=frameskip,
            autoreset=autoreset,
            delay_setup=delay_setup,
            n_workers=n_workers,
            **kwargs,
        )

    @property
    def blocking(self) -> bool:
        """If True the steps are performed sequentially."""
        return self._blocking

    def setup(self):
        """Run environment initialization and create the subprocesses for stepping in parallel."""
        external_callable = self.create_env_callable(autoreset=True, delay_setup=False)
        envs = [ExternalProcess(constructor=external_callable) for _ in range(self.n_workers)]
        self._batch_env = BatchEnv(envs, blocking=self._blocking)
        # Initialize local copy last to tolerate singletons better
        super().setup()

    def clone(self, **kwargs) -> "PlanEnv":
        """Return a copy of the environment.

        Args:
            **kwargs: Overrides forwarded to the parent ``clone``. ``blocking``
                defaults to the value used by this instance.

        """
        default_kwargs = {"blocking": self.blocking}
        default_kwargs.update(kwargs)
        return super().clone(**default_kwargs)

    def make_transitions(
        self,
        actions: numpy.ndarray,
        states: numpy.ndarray = None,
        dt: numpy.ndarray | int = 1,
        return_state: bool | None = None,
    ):
        """Vectorized version of the ``step`` method.

        It allows to step a vector of states and actions. The signature and
        behaviour is the same as ``step``, but taking a list of states, actions
        and dts as input. The work is delegated to the internal :class:`BatchEnv`.

        Args:
            actions: Iterable containing the different actions to be applied.
            states: Iterable containing the different states to be set.
            dt: int or array containing the frameskips that will be applied.
            return_state: Whether to return the state in the returned tuple.
                If None, `step` will return the state if `state` was passed as a parameter.

        Returns:
            if states is None returns ``(observs, rewards, ends, truncateds, infos)`` else
            ``(new_states, observs, rewards, ends, truncateds, infos)``

        """
        return self._batch_env.make_transitions(
            actions=actions,
            states=states,
            dt=dt,
            return_state=return_state,
        )

    def sync_states(self, state=None):
        """Synchronize all the copies of the wrapped environment.

        Set all the states of the different workers of the internal :class:`BatchEnv`
        to the same state.

        Args:
            state: State to set on every worker. If ``None``, the value returned
                by ``self.get_state()`` is used.

        """
        state = self.get_state() if state is None else state
        self._batch_env.sync_states(state)

    def close(self) -> None:
        """Close the environment and the spawned processes."""
        # ``hasattr`` also covers the case where the batch environment is still None.
        if hasattr(self._batch_env, "close"):
            self._batch_env.close()
        if hasattr(self.gym_env, "close"):
            self.gym_env.close()