Fix 2 bugs and refactor RunningMeanStd to support dict obs norm #695

Merged: 4 commits, Jul 15, 2022
20 changes: 18 additions & 2 deletions docs/tutorials/cheatsheet.rst
@@ -1,9 +1,13 @@
Cheat Sheet
===========

This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
This page shows some code snippets of how to use Tianshou to develop new
algorithms / apply algorithms to new scenarios.

By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
By the way, some of these issues can be resolved by using a ``gym.Wrapper``.
It could be a universal solution in the policy-environment interaction. But
you can also use the batch processor :ref:`preprocess_fn` or vectorized
environment wrapper :class:`~tianshou.env.VectorEnvWrapper`.


.. _network_api:
@@ -22,6 +26,18 @@ Build New Policy
See :class:`~tianshou.policy.BasePolicy`.


.. _eval_policy:

Manually Evaluate Policy
------------------------

If you'd like to manually see the action generated by a well-trained agent:
::

# assume obs is a single environment observation
action = policy(Batch(obs=np.array([obs]))).act[0]


.. _customize_training:

Customize Training Process
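(Illustrative, not part of the diff.) A minimal end-to-end sketch expanding the "Manually Evaluate Policy" snippet added above, assuming ``policy`` is an already-trained Tianshou policy and the classic ``gym`` step API:

    import gym
    import numpy as np
    from tianshou.data import Batch

    env = gym.make("CartPole-v1")
    obs = env.reset()
    done = False
    while not done:
        # wrap the single observation into a batch of size 1
        action = policy(Batch(obs=np.array([obs]))).act[0]
        # for continuous-action policies the raw act may need policy.map_action(...)
        obs, rew, done, info = env.step(action)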
6 changes: 6 additions & 0 deletions docs/tutorials/dqn.rst
@@ -256,6 +256,12 @@ Watch the Agent's Performance
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)

If you'd like to manually see the action generated by a well-trained agent:
::

# assume obs is a single environment observation
action = policy(Batch(obs=np.array([obs]))).act[0]


.. _customized_trainer:

9 changes: 9 additions & 0 deletions test/base/test_env.py
@@ -211,6 +211,14 @@ def assert_get(v, expected):
v.close()


def test_attr_unwrapped():
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
train_envs.set_env_attr("test_attribute", 1337)
assert train_envs.get_env_attr("test_attribute") == [1337]
assert hasattr(train_envs.workers[0].env, "test_attribute")
assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute")


def test_env_obs_dtype():
for obs_type in ["array", "object"]:
envs = SubprocVectorEnv(
@@ -349,6 +357,7 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
test_venv_wrapper_envpool()
test_env_obs_dtype()
test_vecenv()
test_attr_unwrapped()
test_async_env()
test_async_check_id()
test_env_reset_optional_kwargs()
2 changes: 1 addition & 1 deletion tianshou/data/collector.py
@@ -145,7 +145,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
)
obs = processed_data.get("obs", obs)
info = processed_data.get("info", info)
self.data.info = info
self.data.info = info
else:
obs = rval
if self.preprocess_fn:
10 changes: 1 addition & 9 deletions tianshou/env/venv_wrappers.py
@@ -68,24 +68,17 @@ class VectorEnvNormObs(VectorEnvWrapper):
"""An observation normalization wrapper for vectorized environments.

:param bool update_obs_rms: whether to update obs_rms. Default to True.
:param float clip_obs: the maximum absolute value for observation. Default to
10.0.
:param float epsilon: To avoid division by zero.
"""

def __init__(
self,
venv: BaseVectorEnv,
update_obs_rms: bool = True,
clip_obs: float = 10.0,
epsilon: float = np.finfo(np.float32).eps.item(),
) -> None:
super().__init__(venv)
# initialize observation running mean/std
self.update_obs_rms = update_obs_rms
self.obs_rms = RunningMeanStd()
self.clip_max = clip_obs
self.eps = epsilon

def reset(
self,
@@ -127,8 +120,7 @@ def step(

def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
if self.obs_rms:
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
obs = np.clip(obs, -self.clip_max, self.clip_max)
return self.obs_rms.norm(obs) # type: ignore
return obs

def set_obs_rms(self, obs_rms: RunningMeanStd) -> None:
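(Illustrative, not part of the diff.) With normalization now delegated to ``RunningMeanStd.norm``, a typical usage sketch for the wrapper, assuming the usual train/test split and a matching ``get_obs_rms`` accessor on the wrapper:

    import gym
    from tianshou.env import DummyVectorEnv, VectorEnvNormObs

    # normalize training observations and keep updating the running statistics
    train_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(8)])
    )
    # test envs reuse the training statistics but freeze them
    test_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(8)]),
        update_obs_rms=False,
    )
    test_envs.set_obs_rms(train_envs.get_obs_rms())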
2 changes: 1 addition & 1 deletion tianshou/env/worker/dummy.py
@@ -17,7 +17,7 @@ def get_env_attr(self, key: str) -> Any:
return getattr(self.env, key)

def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env, key, value)
setattr(self.env.unwrapped, key, value)

def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
if "seed" in kwargs:
2 changes: 1 addition & 1 deletion tianshou/env/worker/ray.py
@@ -14,7 +14,7 @@
class _SetAttrWrapper(gym.Wrapper):

def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env, key, value)
setattr(self.env.unwrapped, key, value)

def get_env_attr(self, key: str) -> Any:
return getattr(self.env, key)
8 changes: 4 additions & 4 deletions tianshou/env/worker/subproc.py
@@ -49,9 +49,9 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces])
elif isinstance(space, gym.spaces.Tuple): # type: ignore
assert isinstance(space.spaces, tuple) # type: ignore
return tuple([_setup_buf(t) for t in space.spaces]) # type: ignore
else:
return ShArray(space.dtype, space.shape) # type: ignore

@@ -122,7 +122,7 @@ def _encode_obs(
elif cmd == "getattr":
p.send(getattr(env, data) if hasattr(env, data) else None)
elif cmd == "setattr":
setattr(env, data["key"], data["value"])
setattr(env.unwrapped, data["key"], data["value"])
else:
p.close()
raise NotImplementedError
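(Illustrative, not part of the diff.) The motivation for routing ``set_env_attr`` through ``env.unwrapped`` in the dummy, ray, and subproc workers: ``gym.Wrapper`` forwards attribute reads to the inner env via ``__getattr__``, but a plain ``setattr`` on the wrapper only lands on the outermost wrapper object, so ``get_env_attr`` and ``set_env_attr`` previously hit different objects. A small sketch of the asymmetry:

    import gym

    env = gym.make("CartPole-v1")         # gym.make returns a wrapped env (e.g. TimeLimit)

    setattr(env, "foo", 1)                # lands on the outer wrapper only
    print(hasattr(env.unwrapped, "foo"))  # False: the core env never sees it

    setattr(env.unwrapped, "bar", 2)      # lands on the core env
    print(hasattr(env, "bar"))            # True: Wrapper.__getattr__ forwards the lookup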
20 changes: 18 additions & 2 deletions tianshou/utils/statistics.py
@@ -1,5 +1,5 @@
from numbers import Number
from typing import List, Union
from typing import List, Optional, Union

import numpy as np
import torch
@@ -70,15 +70,31 @@ class RunningMeanStd(object):
"""Calculates the running mean and std of a data stream.

https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

:param mean: the initial mean estimation for data array. Default to 0.
:param std: the initial standard error estimation for data array. Default to 1.
:param float clip_max: the maximum absolute value for data array. Default to
10.0.
:param float epsilon: To avoid division by zero.
"""

def __init__(
self,
mean: Union[float, np.ndarray] = 0.0,
std: Union[float, np.ndarray] = 1.0
std: Union[float, np.ndarray] = 1.0,
clip_max: Optional[float] = 10.0,
epsilon: float = np.finfo(np.float32).eps.item(),
) -> None:
self.mean, self.var = mean, std
self.clip_max = clip_max
self.count = 0
self.eps = epsilon

def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
if self.clip_max:
data_array = np.clip(data_array, -self.clip_max, self.clip_max)
return data_array

def update(self, data_array: np.ndarray) -> None:
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
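(Illustrative, not part of the diff.) With ``clip_max`` and ``epsilon`` folded into ``RunningMeanStd``, normalization becomes a single call; a minimal sketch, assuming the class is re-exported from ``tianshou.utils``:

    import numpy as np
    from tianshou.utils import RunningMeanStd

    rms = RunningMeanStd()            # mean=0.0, std=1.0, clip_max=10.0 by default
    batch = np.random.randn(64, 4)    # e.g. a batch of 64 CartPole observations
    rms.update(batch)                 # update the running mean/var/count in place
    normed = rms.norm(batch)          # (x - mean) / sqrt(var + eps), clipped to [-10, 10]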