Add Trainers as generators (#559)
The new feature makes trainers usable as generators.
The usage pattern is:

```python
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- `epoch` (int): the epoch number
- `epoch_stat` (dict): a large collection of metrics for the current epoch, including the training statistics
- `info` (dict): the usual result dict returned by the non-generator version of the trainer; a short sketch of using these values follows
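
A minimal sketch of custom per-epoch logic built on these values (the constructor arguments are elided as above, and `reward_threshold` is a hypothetical variable):

```python
trainer = OnpolicyTrainer(...)  # same arguments as the functional onpolicy_trainer

for epoch, epoch_stat, info in trainer:
    # info carries the usual aggregate results, e.g. info["best_reward"]
    print(f"Epoch #{epoch}, best reward so far: {info['best_reward']:.2f}")
    if info["best_reward"] >= reward_threshold:  # hypothetical early-stopping criterion
        break  # stop before max_epoch is reached
```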

You can even iterate on several different trainers at the same time:

```python
trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```
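
Each `result` is the same `(epoch, epoch_stat, info)` tuple, so a comparison loop might look like the following sketch (`compare_results` above and the per-run configuration are placeholders):

```python
for (epoch1, stat1, info1), (epoch2, stat2, info2) in zip(trainer1, trainer2):
    # e.g. compare the aggregate results of the two runs after every epoch
    print(f"Epoch {epoch1}: {info1['best_reward']:.2f} vs {info2['best_reward']:.2f}")
```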

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
jamartinh and Trinkle23897 authored Mar 17, 2022
1 parent 2336a7d commit 10d9190
Showing 14 changed files with 864 additions and 488 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE.md
@@ -7,6 +7,6 @@
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
-import tianshou, torch, numpy, sys
-print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
+import tianshou, gym, torch, numpy, sys
+print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
```
4 changes: 1 addition & 3 deletions Makefile
@@ -22,10 +22,8 @@ lint:
	flake8 ${LINT_PATHS} --count --show-source --statistics

format:
-	# sort imports
	$(call check_install, isort)
	isort ${LINT_PATHS}
-	# reformat using yapf
	$(call check_install, yapf)
	yapf -ir ${LINT_PATHS}

@@ -57,6 +55,6 @@ doc-clean:

clean: doc-clean

-commit-checks: format lint mypy check-docstyle spelling
+commit-checks: lint check-codestyle mypy check-docstyle spelling

.PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks
44 changes: 43 additions & 1 deletion docs/api/tianshou.trainer.rst
@@ -1,7 +1,49 @@
tianshou.trainer
================

.. automodule:: tianshou.trainer

On-policy
---------

.. autoclass:: tianshou.trainer.OnpolicyTrainer
    :members:
    :undoc-members:
    :show-inheritance:

.. autofunction:: tianshou.trainer.onpolicy_trainer

.. autoclass:: tianshou.trainer.onpolicy_trainer_iter


Off-policy
----------

.. autoclass:: tianshou.trainer.OffpolicyTrainer
    :members:
    :undoc-members:
    :show-inheritance:

.. autofunction:: tianshou.trainer.offpolicy_trainer

.. autoclass:: tianshou.trainer.offpolicy_trainer_iter


Offline
-------

.. autoclass:: tianshou.trainer.OfflineTrainer
    :members:
    :undoc-members:
    :show-inheritance:

.. autofunction:: tianshou.trainer.offline_trainer

.. autoclass:: tianshou.trainer.offline_trainer_iter


utils
-----

.. autofunction:: tianshou.trainer.test_episode

.. autofunction:: tianshou.trainer.gather_info
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
@@ -24,12 +24,15 @@ fqf
iqn
qrdqn
rl
offpolicy
onpolicy
quantile
quantiles
dqn
param
async
subprocess
deque
nn
equ
cql
20 changes: 20 additions & 0 deletions docs/tutorials/concepts.rst
@@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training metho

Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.

We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic:
::

    trainer = OnpolicyTrainer(...)
    for epoch, epoch_stat, info in trainer:
        print(f"Epoch: {epoch}")
        print(epoch_stat)
        print(info)
        do_something_with_policy()
        query_something_about_policy()
        make_a_plot_with(epoch_stat)
        display(info)

    # or even iterate on several trainers at the same time

    trainer1 = OnpolicyTrainer(...)
    trainer2 = OnpolicyTrainer(...)
    for result1, result2, ... in zip(trainer1, trainer2, ...):
        compare_results(result1, result2, ...)


.. _pseudocode:

14 changes: 10 additions & 4 deletions test/continuous/test_ppo.py
@@ -11,7 +11,7 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -157,7 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
print("Fail to restore policy and optim.")

# trainer
result = onpolicy_trainer(
trainer = OnpolicyTrainer(
policy,
train_collector,
test_collector,
@@ -173,10 +173,16 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
        resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn
    )
-    assert stop_fn(result['best_reward'])

+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)

+    assert stop_fn(info["best_reward"])

    if __name__ == '__main__':
-        pprint.pprint(result)
+        pprint.pprint(info)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
2 changes: 1 addition & 1 deletion test/continuous/test_sac_with_il.py
@@ -24,7 +24,7 @@ def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--reward-threshold', type=float, default=None)
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
19 changes: 12 additions & 7 deletions test/continuous/test_td3.py
@@ -11,7 +11,7 @@
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3Policy
-from tianshou.trainer import offpolicy_trainer
+from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
@@ -135,8 +135,8 @@ def save_fn(policy):
    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

-    # trainer
-    result = offpolicy_trainer(
+    # Iterator trainer
+    trainer = OffpolicyTrainer(
        policy,
        train_collector,
        test_collector,
@@ -148,12 +148,17 @@ def stop_fn(mean_rewards):
        update_per_step=args.update_per_step,
        stop_fn=stop_fn,
        save_fn=save_fn,
-        logger=logger
+        logger=logger,
    )
-    assert stop_fn(result['best_reward'])
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)

-    if __name__ == '__main__':
-        pprint.pprint(result)
+    assert stop_fn(info["best_reward"])

+    if __name__ == "__main__":
+        pprint.pprint(info)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
16 changes: 11 additions & 5 deletions test/offline/test_cql.py
@@ -12,7 +12,7 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import CQLPolicy
-from tianshou.trainer import offline_trainer
+from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -195,7 +195,7 @@ def watch():
        collector.collect(n_episode=1, render=1 / 35)

    # trainer
-    result = offline_trainer(
+    trainer = OfflineTrainer(
        policy,
        buffer,
        test_collector,
@@ -207,11 +207,17 @@ def watch():
        stop_fn=stop_fn,
        logger=logger,
    )
-    assert stop_fn(result['best_reward'])

+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)

+    assert stop_fn(info["best_reward"])

    # Let's watch its performance!
-    if __name__ == '__main__':
-        pprint.pprint(result)
+    if __name__ == "__main__":
+        pprint.pprint(info)
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
30 changes: 24 additions & 6 deletions tianshou/trainer/__init__.py
@@ -1,16 +1,34 @@
"""Trainer package."""

# isort:skip_file

from tianshou.trainer.utils import test_episode, gather_info
from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
from tianshou.trainer.offline import offline_trainer
from tianshou.trainer.base import BaseTrainer
from tianshou.trainer.offline import (
OfflineTrainer,
offline_trainer,
offline_trainer_iter,
)
from tianshou.trainer.offpolicy import (
OffpolicyTrainer,
offpolicy_trainer,
offpolicy_trainer_iter,
)
from tianshou.trainer.onpolicy import (
OnpolicyTrainer,
onpolicy_trainer,
onpolicy_trainer_iter,
)
from tianshou.trainer.utils import gather_info, test_episode

__all__ = [
"BaseTrainer",
"offpolicy_trainer",
"offpolicy_trainer_iter",
"OffpolicyTrainer",
"onpolicy_trainer",
"onpolicy_trainer_iter",
"OnpolicyTrainer",
"offline_trainer",
"offline_trainer_iter",
"OfflineTrainer",
"test_episode",
"gather_info",
]
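
For reference, the functional entry points keep their previous behaviour while the new classes expose the iterator interface. A rough sketch (constructor arguments elided as in the tests above; the setup of `policy` and the collectors is assumed to exist):

```python
from tianshou.trainer import OnpolicyTrainer, onpolicy_trainer

# Functional form (existing API): runs all epochs internally and returns the final info dict.
result = onpolicy_trainer(policy, train_collector, test_collector, ...)

# Class form (new in this commit): yields (epoch, epoch_stat, info) after each epoch.
trainer = OnpolicyTrainer(policy, train_collector, test_collector, ...)
for epoch, epoch_stat, info in trainer:
    pass  # inspect or log intermediate results here
```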