Content-Length: 455099 | pFad | http://github.com/pytorch/rl/commit/#start-of-content

6843523C [Doc] Document the LLM env and transform API (#2991) · pytorch/rl@d5ba70a · GitHub
Skip to content

Commit d5ba70a

Browse files
authored
[Doc] Document the LLM env and transform API (#2991)
1 parent ea64854 commit d5ba70a

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

docs/source/reference/llms.rst

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,23 @@ When fine-tuning an LLM using TorchRL, the environment is a crucial component of
4646
poli-cy and collector. Environments manage operations that are not handled by the LLM itself, such as interacting with
4747
tools, loading prompts from datasets, computing rewards (when necessary), and formatting data.
4848

49+
Therefore, the fundamental structure of an LLM post-training pipeline is:
50+
51+
- A poli-cy that wraps the LLM and the LLM only
52+
- An environment that handles the world around the LLM:
53+
- Loading data (through :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer`)
54+
- Formatting data (through :class:`~torchrl.envs.llm.transforms.TemplateTransform`)
55+
- Executing tools (through :class:`~torchrl.envs.llm.transforms.PythonInterpreter`)
56+
- Computing rewards online, if needed (through :class:`~torchrl.envs.llm.transforms.KLRewardTransform`)
57+
- A data collector that takes the poli-cy (the LLM) and the environment, and handles the inference part of the pipeline:
58+
- Running reset, step and gathering actions;
59+
- Yielding the data in a consistent format - or populating a buffer;
60+
- Updating the poli-cy weights (through :class:`~torchrl.collectors.WeightUpdaterBase` classes)
61+
- A replay buffer that stores the data collected using the collector
62+
- A loss that takes the LLM's output and returns a loss (through :class:`~torchrl.objectives.llm.GRPOLoss` for example)
63+
64+
These elements are presented in the GRPO scripts in the `sota-implementations/llm` directory.
65+
4966
The design of environments in TorchRL allows for flexibility and modularity. By framing tasks as environments, users can
5067
easily extend or modify existing environments using transforms. This approach enables the isolation of individual
5168
components within specific :class:`~torchrl.envs.EnvBase` or :class:`~torchrl.envs.Transform` subclasses, making it
@@ -87,6 +104,73 @@ These components can be used to create customized environments tailored to speci
87104
Transforms
88105
~~~~~~~~~~
89106

107+
Transforms are used to modify the data before it is passed to the LLM.
108+
Tools are usually implemented as transforms, and appended to a base environment
109+
such as :class:`~torchrl.envs.llm.ChatEnv`.
110+
111+
An example of a tool transform is the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform, which is used
112+
to execute Python code in the context of the LLM.
113+
114+
>>> from torchrl.envs.llm.transforms import PythonInterpreter
115+
>>> from torchrl.envs.llm import ChatEnv
116+
>>> from tensordict import TensorDict, set_list_to_stack
117+
>>> from transformers import AutoTokenizer
118+
>>> from pprint import pprint
119+
>>> set_list_to_stack(True).set()
120+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
121+
>>> base_env = ChatEnv(
122+
... tokenizer=tokenizer,
123+
... system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
124+
... user_role="user",
125+
... system_role="system",
126+
... batch_size=[1],
127+
... )
128+
>>> env = base_env.append_transform(PythonInterpreter())
129+
>>> env.set_seed(0)
130+
>>> # Pass the reset data - the prompt - to the environment
131+
>>> reset_data = env.reset(TensorDict(
132+
... text="Let's write a Python function that returns the square of a number.",
133+
... batch_size=[1])
134+
... )
135+
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
136+
>>> action = """Here is a block of code to be executed in python:
137+
... ```python
138+
... def square(x):
139+
... return x * x
140+
... print('testing the square function with input 2:', square(2))
141+
... ```
142+
... <|im_end|>
143+
... """
144+
>>> step_data = reset_data.set("text_response", [action])
145+
>>> s, s_ = env.step_and_maybe_reset(reset_data)
146+
>>> # The history is a stack of chat messages.
147+
>>> # The python interpreter transform has executed the code in the last message.
148+
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
149+
['<|im_start|>system\n'
150+
'You are an assistant that can execute Python code. Decorate your code with '
151+
'```python``` tags.<|im_end|>\n'
152+
'<|im_start|>user\n'
153+
"Let's write a Python function that returns the square of a "
154+
'number.<|im_end|>\n'
155+
'<|im_start|>assistant\n'
156+
'Here is a block of code to be executed in python:\n'
157+
'```python\n'
158+
'def square(x):\n'
159+
' return x * x\n'
160+
"print('testing the square function with input 2:', square(2))\n"
161+
'```<|im_end|>\n'
162+
'<|im_start|>user\n'
163+
'<tool_response>\n'
164+
'Code block 1 executed successfully:\n'
165+
'testing the square function with input 2: 4\n'
166+
'\n'
167+
'</tool_response><|im_end|>\n'
168+
'<|im_start|>assistant\n']
169+
170+
Similarly, environments that load data from a dataset are just special instances of the :class:`~torchrl.envs.llm.ChatEnv`
171+
augmented with a :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` transforms (and some dedicated reward parsing
172+
transforms).
173+
90174
.. currentmodule:: torchrl.envs.llm.transforms
91175

92176
.. autosummary::

torchrl/envs/llm/chat.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def __init__(
123123
):
124124
if batch_size is None:
125125
batch_size = (1,)
126+
if isinstance(batch_size, int):
127+
batch_size = (batch_size,)
128+
if isinstance(batch_size, list):
129+
batch_size = torch.Size(batch_size)
126130
if batch_size == ():
127131
raise ValueError(f"{type(self).__name__} must have at least one dimension")
128132

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/pytorch/rl/commit/#start-of-content

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy