Skip to content

Commit 0a9d97e

Browse files
leonardhussenotcopybara-github
authored andcommitted
Adding a wrapper that concatenates observation in a single tensor.
It allows to run seamlessly on DMControl and Gym environments. PiperOrigin-RevId: 413688077 Change-Id: Id59f6bf2800088c71438e3a0e0eaa5d9debdf9ed
1 parent 32156b2 commit 0a9d97e

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

acme/wrappers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from acme.wrappers.base import EnvironmentWrapper
2020
from acme.wrappers.base import wrap_all
2121
from acme.wrappers.canonical_spec import CanonicalSpecWrapper
22+
from acme.wrappers.concatenate_observations import ConcatObservationWrapper
2223
from acme.wrappers.frame_stacking import FrameStackingWrapper
2324
from acme.wrappers.gym_wrapper import GymAtariAdapter
2425
from acme.wrappers.gym_wrapper import GymWrapper
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# python3
2+
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Wrapper that implements concatenation of observation fields."""
17+
18+
from typing import Sequence, Optional
19+
20+
from acme import types
21+
from acme.jax import utils
22+
from acme.wrappers import base
23+
import dm_env
24+
import numpy as np
25+
import tree
26+
27+
28+
def _concat(values: types.NestedArray) -> np.ndarray:
29+
"""Concatenates the leaves of `values` along the leading dimension.
30+
31+
Treats scalars as 1d arrays and expects that the shapes of all leaves are
32+
the same except for the leading dimension.
33+
34+
Args:
35+
values: the nested arrays to concatenate.
36+
Returns:
37+
The concatenated array.
38+
"""
39+
leaves = list(map(np.atleast_1d, tree.flatten(values)))
40+
return np.concatenate(leaves)
41+
42+
43+
class ConcatObservationWrapper(base.EnvironmentWrapper):
44+
"""Wrapper that concatenates observation fields.
45+
46+
It takes an environment with nested observations and concatenates the fields
47+
in a single tensor. The orginial fields should be 1-dimensional.
48+
Observation fields that are not in name_filter are dropped.
49+
"""
50+
51+
def __init__(self, environment: dm_env.Environment,
52+
name_filter: Optional[Sequence[str]] = None):
53+
"""Initializes a new ConcatObservationWrapper.
54+
55+
Args:
56+
environment: Environment to wrap.
57+
name_filter: Sequence of observation names to keep. None keeps them all.
58+
"""
59+
super().__init__(environment)
60+
observation_spec = environment.observation_spec()
61+
if name_filter is None:
62+
name_filter = list(observation_spec.keys())
63+
self._obs_names = [x for x in name_filter if x in observation_spec.keys()]
64+
65+
dummy_obs = utils.zeros_like(observation_spec)
66+
dummy_obs = self._convert_observation(dummy_obs)
67+
self._observation_spec = dm_env.specs.BoundedArray(
68+
shape=dummy_obs.shape,
69+
dtype=dummy_obs.dtype,
70+
minimum=-np.inf,
71+
maximum=np.inf,
72+
name='state')
73+
74+
def _convert_observation(self, observation):
75+
obs = {k: observation[k] for k in self._obs_names}
76+
return _concat(obs)
77+
78+
def step(self, action) -> dm_env.TimeStep:
79+
timestep = self._environment.step(action)
80+
return timestep._replace(
81+
observation=self._convert_observation(timestep.observation))
82+
83+
def reset(self) -> dm_env.TimeStep:
84+
timestep = self._environment.reset()
85+
return timestep._replace(
86+
observation=self._convert_observation(timestep.observation))
87+
88+
def observation_spec(self) -> types.NestedSpec:
89+
return self._observation_spec

examples/control/helpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,19 @@
2222

2323
def make_environment(evaluation: bool = False,
2424
domain_name: str = 'cartpole',
25-
task_name: str = 'balance') -> dm_env.Environment:
25+
task_name: str = 'balance',
26+
concatenate_observations: bool = False
27+
) -> dm_env.Environment:
2628
"""Implements a control suite environment factory."""
2729
# Nothing special to be done for evaluation environment.
2830
del evaluation
2931

3032
environment = suite.load(domain_name, task_name)
3133
environment = wrappers.SinglePrecisionWrapper(environment)
32-
34+
timestep = environment.reset()
35+
obs_names = list(timestep.observation.keys())
36+
if concatenate_observations:
37+
environment = wrappers.ConcatObservationWrapper(environment, obs_names)
3338
return environment
39+
40+

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy