From df4a3e0eb473bdfb9bb87d94c5cf0083ebe988e4 Mon Sep 17 00:00:00 2001
From: yannbouteiller
Date: Sat, 27 Mar 2021 22:37:41 -0400
Subject: [PATCH] support for info dict (backward incompatible)

---
 README.md                   | 24 +++++++++++++-----------
 rtgym/envs/real_time_env.py | 27 ++++++++++++++-------------
 rtgym/tuto/tuto.py          |  5 +++--
 tests/test_all.py           |  4 ++--
 4 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/README.md b/README.md
index e8f5869..418750f 100644
--- a/README.md
+++ b/README.md
@@ -59,7 +59,7 @@ This method defines the core mechanism of Real-Time Gym environments:
 
 Time-steps are being elastically constrained to their nominal duration. When this elastic constraint cannot be satisfied, the previous time-step times out and the new time-step starts from the current timestamp. This happens either because the environment has been 'paused', or because the system is ill-designed:
 - The inference duration of the model, i.e. the elapsed duration between two calls of the step() function, may be too long for the time-step duration that the user is trying to use.
-- The procedure that retrieves observations may take too much time or may be called too late (the latter can be tweaked in the configuration dictionary). Remember that, if observation capture is too long, it must not be part of the get_obs_rew_done() method of your interface. Instead, this method must simply retrieve the latest available observation from another process, and the action buffer must be long enough to handle the observation capture duration. This is described in the Appendix of [Reinforcement Learning with Random Delays](https://arxiv.org/abs/2010.02966).
+- The procedure that retrieves observations may take too much time or may be called too late (the latter can be tweaked in the configuration dictionary). Remember that, if observation capture is too long, it must not be part of the get_obs_rew_done_info() method of your interface. Instead, this method must simply retrieve the latest available observation from another process, and the action buffer must be long enough to handle the observation capture duration. This is described in the Appendix of [Reinforcement Learning with Random Delays](https://arxiv.org/abs/2010.02966).
 
 ## Tutorial
 
@@ -96,7 +96,7 @@ from rtgym import RealTimeGymInterface
 
 The [RealTimeGymInterface](https://github.com/yannbouteiller/rtgym/blob/969799b596e91808543f781b513901426b88d138/rtgym/envs/real_time_env.py#L12) is all you need to implement in order to create your custom real-time Gym environment.
 
-This class has 6 abstract methods that you need to implement: ```get_observation_space```, ```get_action_space```, ```get_default_action```, ```reset```, ```get_obs_rew_done``` and ```send_control```.
+This class has 6 abstract methods that you need to implement: ```get_observation_space```, ```get_action_space```, ```get_default_action```, ```reset```, ```get_obs_rew_done_info``` and ```send_control```.
 It also has a ```wait``` and a ```render``` methods that you may want to override.
 We will implement them all to understand their respective roles.
 
@@ -218,7 +218,7 @@ class MyRealTimeInterface(RealTimeGymInterface):
     def reset(self):
         pass
 
-    def get_obs_rew_done(self):
+    def get_obs_rew_done_info(self):
         pass
 
     def wait(self):
@@ -305,7 +305,7 @@ Ok, in this case this is actually equivalent, but you get the idea. You may want
 
 ---
 The ```get_observation_space``` method outputs a ```gym.spaces.Tuple``` object.
-This object describes the structure of the observations returned from the ```reset``` and ```get_obs_rew_done``` methods of our interface.
+This object describes the structure of the observations returned from the ```reset``` and ```get_obs_rew_done_info``` methods of our interface.
 In our case, the observation will contain ```pos_x``` and ```pos_y```, which are both constrained between ```-1.0``` and ```1.0``` in our simple 2D world.
 It will also contain target coordinates ```tar_x``` and ```tar_y```, constrained between ```-0.5``` and ```0.5```.
 
@@ -324,7 +324,7 @@ def get_observation_space(self):
 
 ---
 We can now implement the RL mechanics of our environment (i.e. the reward function and whether we consider the task ```done``` in the episodic setting), and a procedure to retrieve observations from our dummy drone.
-This is done in the ```get_obs_rew_done``` method.
+This is done in the ```get_obs_rew_done_info``` method.
 
 For this tutorial, we will implement a simple task.
 
@@ -339,22 +339,24 @@ The task is easy, but not as straightforward as it looks.
 Indeed, the presence of random communication delays and the fact that the drone keeps moving in real time makes it difficult to precisely reach the target.
 
 ---
-```get_obs_rew_done``` outputs 3 values:
+```get_obs_rew_done_info``` outputs 4 values:
 - ```obs```: a list of all the components of the last retrieved observation, except for the action buffer
 - ```rew```: a float that is our reward
 - ```done```: a boolean that tells whether the episode is finished (always False in the non-episodic setting)
+- ```info```: a dictionary that contains any additional information you may want to provide
 
 For our simple task, the implementation is fairly straightforward.
-```obs``` contains the last available coordinates and the target, ```rew``` is the negative distance to the target, and ```done``` is True when the target has been reached:
+```obs``` contains the last available coordinates and the target, ```rew``` is the negative distance to the target, ```done``` is True when the target has been reached, and since we don't need more information ```info``` is empty:
 ```python
-def get_obs_rew_done(self):
+def get_obs_rew_done_info(self):
     pos_x, pos_y = self.rc_drone.get_observation()
     tar_x = self.target[0]
     tar_y = self.target[1]
     obs = [pos_x, pos_y, tar_x, tar_y]
     rew = -np.linalg.norm(np.array([pos_x, pos_y], dtype=np.float32) - self.target)
     done = rew > -0.01
-    return obs, rew, done
+    info = {}
+    return obs, rew, done, info
 ```
 
 We did not implement the 100 time-steps limit here because this will be done later in the configuration dictionary.
@@ -416,8 +418,8 @@ The ```rtgym``` environment will ensure that the control frequency sticks to thi
 
 The ```"start_obs_capture"``` entry is usually the same as the ```"time_step_duration"``` entry.
 It defines the time at which an observation starts being retrieved, which should usually happen instantly at the end of the time-step.
-However, in some situations, you will want to actually capture an observation in ```get_obs_rew_done``` and the capture duration will not be negligible.
-In such situations, if observation capture is less than 1 time-step, you can do this and use ```"start_obs_capture"``` in order to tell the environment to call ```get_obs_rew_done``` before the end of the time-step.
+However, in some situations, you will want to actually capture an observation in ```get_obs_rew_done_info``` and the capture duration will not be negligible.
+In such situations, if observation capture is less than 1 time-step, you can do this and use ```"start_obs_capture"``` in order to tell the environment to call ```get_obs_rew_done_info``` before the end of the time-step.
 If observation capture is more than 1 time-step, it needs to be performed in a parallel process and the last available observation should be used at each time-step.
 In any case, keep in mind that when observation capture is not instantaneous, you should add its maximum duration to the maximum delay, and increase the size of the action buffer accordingly.
 See the [Reinforcement Learning with Random Delays](https://arxiv.org/abs/2010.02966) appendix for more details.
diff --git a/rtgym/envs/real_time_env.py b/rtgym/envs/real_time_env.py
index 9519cee..4b96c90 100644
--- a/rtgym/envs/real_time_env.py
+++ b/rtgym/envs/real_time_env.py
@@ -64,17 +64,18 @@ def wait(self):
         """
         self.send_control(self.get_default_action())
 
-    def get_obs_rew_done(self):
-        """Returns observation, reward and done from the device.
+    def get_obs_rew_done_info(self):
+        """Returns observation, reward, done and info from the device.
 
         Returns:
             obs: list
             rew: scalar
-            done: boolean
+            done: bool
+            info: dict
 
         Note: Do NOT put the action buffer in obs (automated).
         """
-        # return obs, rew, done
+        # return obs, rew, done, info
         raise NotImplementedError
 
 
@@ -129,7 +130,7 @@ def render(self):
    # start_obs_capture should be the same as "time_step_duration" unless observation capture is non-instantaneous and
    # smaller than one time-step, and you want to capture it directly in your interface for convenience. Otherwise,
    # you need to perform observation capture in a parallel process and simply retrieve the last available observation
-   # in the get_obs_rew_done() and reset() methods of your interface
+   # in the get_obs_rew_done_info() and reset() methods of your interface
    "time_step_timeout_factor": 1.0,  # maximum elasticity in (fraction or number of) time-steps
    "ep_max_length": 1000,  # maximum episode length
    "real_time": True,  # True unless you want to revert to the usual turn-based RL setting (not tested yet)
@@ -317,6 +318,7 @@ def __init__(self, config: dict=DEFAULT_CONFIG_DICT):
         self.__obs = None
         self.__rew = None
         self.__done = None
+        self.__info = None
         self.__o_set_flag = False
 
         # environment benchmark:
@@ -431,31 +433,31 @@ def __update_obs_rew_done(self):
             observation of this step()
         """
         self.__o_lock.acquire()
-        o, r, d = self.interface.get_obs_rew_done()
+        o, r, d, i = self.interface.get_obs_rew_done_info()
         if not d:
             d = (self.current_step >= self.ep_max_length)
         elt = o
         if self.obs_prepro_func:
             elt = self.obs_prepro_func(elt)
         elt = tuple(elt)
-        self.__obs, self.__rew, self.__done = elt, r, d
+        self.__obs, self.__rew, self.__done, self.__info = elt, r, d, i
         self.__o_set_flag = True
         self.__o_lock.release()
 
-    def _retrieve_obs_rew_done(self):
-        """Waits for new available o r d and retrieves them.
+    def _retrieve_obs_rew_done_info(self):
+        """Waits for new available o r d i and retrieves them.
         """
         c = True
         while c:
             self.__o_lock.acquire()
             if self.__o_set_flag:
-                elt, r, d = self.__obs, self.__rew, self.__done
+                elt, r, d, i = self.__obs, self.__rew, self.__done, self.__info
                 self.__o_set_flag = False
                 c = False
             self.__o_lock.release()
         if self.act_in_obs:
             elt = tuple((*elt, *tuple(self.act_buf),))
-        return elt, r, d
+        return elt, r, d, i
 
     def init_action_buffer(self):
         for _ in range(self.act_buf_len):
@@ -502,8 +504,7 @@ def step(self, action):
         self.act_buf.append(action)
         if not self.real_time:
             self._run_time_step(action)
-        obs, rew, done = self._retrieve_obs_rew_done()
-        info = {}
+        obs, rew, done, info = self._retrieve_obs_rew_done_info()
         if self.real_time:
             self._run_time_step(action)
         if done and self.wait_on_done:
diff --git a/rtgym/tuto/tuto.py b/rtgym/tuto/tuto.py
index 7edffbb..cb2aa4e 100644
--- a/rtgym/tuto/tuto.py
+++ b/rtgym/tuto/tuto.py
@@ -57,14 +57,15 @@ def reset(self):
         self.target[1] = np.random.uniform(-0.5, 0.5)
         return [pos_x, pos_y, self.target[0], self.target[1]]
 
-    def get_obs_rew_done(self):
+    def get_obs_rew_done_info(self):
         pos_x, pos_y = self.rc_drone.get_observation()
         tar_x = self.target[0]
         tar_y = self.target[1]
         obs = [pos_x, pos_y, tar_x, tar_y]
         rew = -np.linalg.norm(np.array([pos_x, pos_y], dtype=np.float32) - self.target)
         done = rew > -0.01
-        return obs, rew, done
+        info = {}
+        return obs, rew, done, info
 
     def wait(self):
         self.send_control(self.get_default_action())
diff --git a/tests/test_all.py b/tests/test_all.py
index d4b39b2..05c8b2c 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -17,8 +17,8 @@ def send_control(self, control):
     def reset(self):
         return [time.time(), self.control, self.control_time]
 
-    def get_obs_rew_done(self):
-        return [time.time(), self.control, self.control_time], 0.0, False
+    def get_obs_rew_done_info(self):
+        return [time.time(), self.control, self.control_time], 0.0, False, {}
 
     def get_observation_space(self):
         ob = gym.spaces.Box(low=np.array([0.0]), high=np.array([np.inf]), dtype=np.float32)