
Enhance README for Tensor Parallelism #1369

Merged
1 commit merged on Jul 14, 2025

Enhance README and examples for Tensor Parallelism
- Added installation instructions and example running commands to README.md.
- Updated the example files for better organization.

Signed-off-by: jafraustro <jaime.fraustro.valdez@intel.com>
jafraustro committed Jul 10, 2025
commit 6beb6bb7e32a63aa8a4b6d00de278bef556747ba
22 changes: 20 additions & 2 deletions distributed/tensor_parallelism/README.md
@@ -10,7 +10,25 @@ PyTorch native Tensor Parallel APIs, which include:
For more details about the PyTorch native Tensor Parallel APIs, please see the PyTorch docs:
https://pytorch.org/docs/stable/distributed.tensor.parallel.html

## Installation

```bash
pip install -r requirements.txt
```

## Running Examples

You can run the examples using `torchrun` to launch distributed training:

```bash
# Simple Tensor Parallel example
torchrun --nnodes=1 --nproc_per_node=4 tensor_parallel_example.py

# Tensor Parallel with Sequence Parallel
torchrun --nnodes=1 --nproc_per_node=4 sequence_parallel_example.py

# FSDP + Tensor Parallel with Llama2 model
torchrun --nnodes=1 --nproc_per_node=4 fsdp_tp_example.py
```

For more details, check the `run_examples.sh` script.
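
As background for the commands above: `torchrun` starts one process per GPU and exports `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` to each process. The sketch below shows, under that assumption, the kind of per-process setup such a script performs before building its parallelism plan. It is only an illustration, not code from this repository, and the helper name `setup_mesh` is made up.

```python
# Minimal sketch of the per-process setup a torchrun-launched script performs.
# Illustrative only; the examples in this directory have their own setup code.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def setup_mesh():
    # torchrun exports these variables for every process it spawns.
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # One mesh dimension spanning all ranks -> plain tensor parallelism.
    # init_device_mesh also initializes the default process group if needed.
    return init_device_mesh("cuda", (world_size,))


if __name__ == "__main__":
    mesh = setup_mesh()
    print(f"rank {dist.get_rank()} ready, mesh: {mesh}")
```
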
61 changes: 30 additions & 31 deletions distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,34 +1,3 @@
import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from log_utils import rank_log, get_logger, verify_min_gpu_count

# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from llama2_model import Transformer, ModelArgs

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
PrepareModuleInput,
SequenceParallel
)


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a example
@@ -60,6 +29,36 @@
https://pytorch.org/tutorials/intermediate/TP_tutorial.html
"""

import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from log_utils import rank_log, get_logger, verify_min_gpu_count

# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from llama2_model import Transformer, ModelArgs

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
PrepareModuleInput,
SequenceParallel
)

tp_size = 2
logger = get_logger()

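
As a companion to the relocated docstring and imports above, here is a minimal sketch of the 2-D layering this example describes: a tensor-parallel plan applied on one mesh dimension and FSDP on the other. It swaps in a small `nn.Sequential` for the Llama2 `Transformer` and leaves out the `SequenceParallel` and `PrepareModuleInput` styles the real example also uses, so the mesh shape, layer sizes, and module names are illustrative assumptions.

```python
# Illustrative 2-D sketch of TP + FSDP: a (2 x 2) device mesh on 4 GPUs,
# "dp" for FSDP and "tp" for tensor parallelism, with a stand-in model
# instead of the Llama2 Transformer (names and sizes are assumptions).
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()

# First apply the TP plan on the "tp" sub-mesh: shard layer 0 column-wise
# and layer 2 row-wise (nn.Sequential names its children "0", "1", "2").
model = parallelize_module(
    model,
    mesh_2d["tp"],
    {"0": ColwiseParallel(), "2": RowwiseParallel()},
)

# Then shard the resulting DTensor parameters over "dp" with FSDP.
model = FSDP(model, device_mesh=mesh_2d["dp"], use_orig_params=True)
```

Applying the TP plan before the FSDP wrap matters: FSDP then shards parameters that are already DTensors along the data-parallel dimension, which is the ordering the example itself relies on.
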
41 changes: 19 additions & 22 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,5 +1,22 @@
# The following is an example command to run this code
# torchrun --nnodes 1 --nproc-per-node 4 sequence_parallel_example.py
"""
This is the script to test Sequence Parallel(SP) on a toy model in a
Megetron-LM SPMD style. We show an E2E working flow from forward,
backward and optimization.

We use the example of two `nn.Linear` layers with an element-wise `nn.RELU`
in between to show an example of sequence parallel, which was proposed in paper:

https://arxiv.org/pdf/2205.05198.pdf.

Like tensor parallel, we parallelize the first linear layer by column
and also parallelize the second linear layer by row. But the input in each rank
now is different so that we need one all-gather for input and one reduce-scatter
in the end of the second linear layer.

The following is an example command to run this code
torchrun --nnodes 1 --nproc-per-node 4 sequence_parallel_example.py
"""

import os
import sys
import torch
@@ -24,28 +41,8 @@
sys.exit()
# ---------------------------


from torch.distributed._tensor.device_mesh import init_device_mesh



"""
This is the script to test Sequence Parallel(SP) on a toy model in a
Megetron-LM SPMD style. We show an E2E working flow from forward,
backward and optimization.

We use the example of two `nn.Linear` layers with an element-wise `nn.RELU`
in between to show an example of sequence parallel, which was proposed in paper:

https://arxiv.org/pdf/2205.05198.pdf.

Like tensor parallel, we parallelize the first linear layer by column
and also parallelize the second linear layer by row. But the input in each rank
now is different so that we need one all-gather for input and one reduce-scatter
in the end of the second linear layer.
"""


class ToyModel(nn.Module):
"""MLP based model"""

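
The docstring moved to the top of this file describes the sequence-parallel layout in prose. A minimal sketch of that plan, assuming a toy module whose projection layers are named `in_proj` and `out_proj` (names chosen here for illustration), could look like this:

```python
# Illustrative Sequence Parallel plan: activations enter and leave sharded
# on the sequence dimension (dim 0 here); the ParallelStyles insert the
# all-gather before in_proj and the reduce-scatter after out_proj.
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(32, 64)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(64, 32)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
sp_model = parallelize_module(
    ToyMLP().cuda(),
    mesh,
    {
        # Incoming activations are declared sharded on the sequence dim.
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        # Outgoing activations are reduce-scattered back to Shard(0).
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

# Each rank feeds its own slice of the sequence.
local_chunk = torch.rand(16, 32, device="cuda")
out = sp_model(local_chunk)  # output is likewise sharded on dim 0
```

Unlike plain tensor parallelism, each rank feeds a different slice of the sequence, which is why the input layout is declared as `Shard(0)`.
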
53 changes: 25 additions & 28 deletions distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,31 +1,3 @@
# The following is an example command to run this code
# torchrun --nnodes 1 --nproc-per-node 4 tensor_parallel_example.py
import os
import sys
import torch
import torch.nn as nn

from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
)

from log_utils import rank_log, get_logger, verify_min_gpu_count

# ---- GPU check ------------
_min_gpu_count = 2

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh




"""
This is the script to test Tensor Parallel(TP) on a toy model in a
Megetron-LM SPMD style. We show an E2E working flow from forward,
@@ -55,8 +27,33 @@
to use and our `parallelize_module` API will parse and parallelize the modules
based on the given `ParallelStyle`. We are using this PyTorch native Tensor
Parallelism APIs in this example to show users how to use them.

The following is an example command to run this code
torchrun --nnodes 1 --nproc-per-node 4 tensor_parallel_example.py
"""

import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
)
from log_utils import rank_log, get_logger, verify_min_gpu_count

# ---- GPU check ------------
_min_gpu_count = 2

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh

class ToyModel(nn.Module):
"""MLP based model"""

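
To make the relocated docstring concrete, the following is a hedged end-to-end sketch of the same pattern: build a toy MLP, apply the column-wise/row-wise plan, then run forward, backward, and an optimizer step. It assumes a `torchrun` launch with at least two GPUs and a shared RNG seed so every rank sees identical input; it is a sketch of the technique, not this example file.

```python
# Illustrative end-to-end TP step (forward, backward, optimizer) on a toy
# MLP. Assumes a torchrun launch with at least 2 GPUs; names are made up.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
tp_model = parallelize_module(
    ToyModel().cuda(),
    mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)

optimizer = torch.optim.AdamW(tp_model.parameters(), lr=1e-3)
torch.manual_seed(0)  # plain TP expects identical inputs on every rank
for _ in range(3):
    inp = torch.rand(8, 10, device="cuda")
    out = tp_model(inp)
    out.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
```
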