diff --git a/run_python_examples.sh b/run_python_examples.sh
index 2d769c0ae1..ef6d0228c5 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -167,6 +167,10 @@ function gat() {
   uv run main.py --epochs 1 --dry-run || error "graph attention network failed"
 }
 
+function swin_transformer() {
+  uv run swin_transformer.py --epochs 1 --dry-run || error "swin transformer failed"
+}
+
 eval "base_$(declare -f stop)"
 
 function stop() {
@@ -191,8 +195,8 @@ function stop() {
     time_sequence_prediction/traindata.pt \
     word_language_model/model.pt \
     gcn/cora/ \
-    gat/cora/ || error "couldn't clean up some files"
-
+    gat/cora/ \
+    swin_transformer/swin_cifar10.pt || error "couldn't clean up some files"
   git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
 
   base_stop "$1"
@@ -220,6 +224,7 @@ function run_all() {
   run fx
   run gcn
   run gat
+  run swin_transformer
 }
 
 # by default, run all examples
diff --git a/swin_transformer/README.md b/swin_transformer/README.md
new file mode 100644
index 0000000000..37f789be37
--- /dev/null
+++ b/swin_transformer/README.md
@@ -0,0 +1,61 @@
+# Swin Transformer on CIFAR-10
+
+This project demonstrates a minimal implementation of a **Swin Transformer** for image classification on the **CIFAR-10** dataset using PyTorch.
+
+It includes:
+- Patch embedding and window-based self-attention
+- Shifted windows for hierarchical representation
+- Training and testing logic using standard PyTorch utilities
+
+---
+
+## Files
+
+- `swin_transformer.py` — Full implementation of the Swin Transformer model, training loop, and evaluation on CIFAR-10.
+- `README.md` — This file.
+
+---
+
+## Requirements
+
+- Python 3.9+
+- PyTorch 2.6 or later
+- `torchvision` (for CIFAR-10 dataset)
+
+Install dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+---
+
+## Usage
+
+### Train & Save the model
+
+```bash
+python swin_transformer.py --epochs 10 --batch-size 64 --lr 0.001 --save-model
+```
+
+### Test the model
+
+Evaluation runs automatically after every epoch; there is no test-only mode. For a quick check, train for a single epoch:
+
+```bash
+python swin_transformer.py --epochs 1
+```
+
+When `--save-model` is passed, the model is saved as `swin_cifar10.pt`.
+
+---
+
+## Features
+
+- Uses shifted window attention for local self-attention.
+- Patch-based embedding with a lightweight network.
+- Trains on CIFAR-10 with `Adam` optimizer and learning rate scheduling.
+- Prints loss and accuracy per epoch.
+
+---
+
diff --git a/swin_transformer/requirements.txt b/swin_transformer/requirements.txt
new file mode 100644
index 0000000000..9a083ba390
--- /dev/null
+++ b/swin_transformer/requirements.txt
@@ -0,0 +1,2 @@
+torch>=2.6
+torchvision
diff --git a/swin_transformer/swin_transformer.py b/swin_transformer/swin_transformer.py
new file mode 100644
index 0000000000..a29fbd5fff
--- /dev/null
+++ b/swin_transformer/swin_transformer.py
@@ -0,0 +1,228 @@
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR
+from torchvision import datasets, transforms
+
+# ---------- Core Swin Components ----------
+
+class PatchEmbed(nn.Module):
+    """Embed non-overlapping image patches with a strided convolution."""
+
+    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=48):
+        super().__init__()
+        # img_size is not used by the projection; kept for a familiar signature.
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, x):
+        x = self.proj(x)  # (B, embed_dim, H / patch, W / patch)
+        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
+        x = self.norm(x)
+        return x
+
+def window_partition(x, window_size):
+    """Split a (B, H, W, C) map into (B * num_windows, ws, ws, C) windows."""
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+def window_reverse(windows, window_size, H, W):
+    """Undo window_partition, returning a (B, H, W, C) feature map."""
+    # Pure integer arithmetic; the int(a / b) form round-trips through float.
+    B = windows.shape[0] // ((H // window_size) * (W // window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+class WindowAttention(nn.Module):
+    """Multi-head self-attention applied independently within each window."""
+
+    def __init__(self, dim, window_size, num_heads):
+        super().__init__()
+        # window_size is unused: this minimal version has no position bias.
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (B_, heads, N, head_dim)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        return self.proj(out)
+
+class SwinTransformerBlock(nn.Module):
+    """Window attention (optionally shifted) followed by an MLP."""
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=4, shift_size=0):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.window_size = window_size
+        self.shift_size = shift_size
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn = WindowAttention(dim, window_size, num_heads)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, dim * 4),
+            nn.GELU(),
+            nn.Linear(dim * 4, dim)
+        )
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        shortcut = x
+        x = x.view(B, H, W, C)
+
+        # NOTE: unlike the reference Swin, shifted windows are not masked here.
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, (-self.shift_size, -self.shift_size), (1, 2))
+        else:
+            shifted_x = x
+
+        windows = window_partition(shifted_x, self.window_size)
+        windows = windows.view(-1, self.window_size * self.window_size, C)
+
+        attn_windows = self.attn(self.norm1(windows))
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+
+        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
+
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, (self.shift_size, self.shift_size), (1, 2))
+        else:
+            x = shifted_x
+
+        # Residual connection around attention, as in the Swin block.
+        x = shortcut + x.view(B, H * W, C)
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+# ---------- Final Network ----------
+
+class SwinTinyNet(nn.Module):
+    """Two Swin blocks on 4x4 patches of a 32x32 image, then a linear head."""
+
+    def __init__(self, num_classes=10):
+        super(SwinTinyNet, self).__init__()
+        self.patch_embed = PatchEmbed(img_size=32, patch_size=4, in_chans=3, embed_dim=48)
+        self.block1 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=0)
+        self.block2 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=2)
+        self.norm = nn.LayerNorm(48)
+        self.fc = nn.Linear(48, num_classes)
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.norm(x)
+        x = x.mean(dim=1)  # average over patch tokens
+        x = self.fc(x)
+        return F.log_softmax(x, dim=1)
+
+# ---------- Training and Testing ----------
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    """Run one training epoch (a single batch when --dry-run is set)."""
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+            if args.dry_run:
+                break
+
+def test(args, model, device, test_loader):
+    """Evaluate on the test set and print average loss and accuracy."""
+    model.eval()
+    test_loss = 0
+    correct = 0
+    seen = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+            seen += len(data)
+            if args.dry_run:
+                break
+
+    # --dry-run stops after one batch, so average over what was evaluated.
+    seen = max(seen, 1)
+    test_loss /= seen
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, seen,
+        100. * correct / seen))
+
+# ---------- Main ----------
+
+def main():
+    """Parse arguments, then train and evaluate SwinTinyNet on CIFAR-10."""
+    parser = argparse.ArgumentParser(description='Swin Transformer CIFAR10 Example')
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--test-batch-size', type=int, default=1000)
+    parser.add_argument('--epochs', type=int, default=10)
+    parser.add_argument('--lr', type=float, default=0.01)
+    parser.add_argument('--gamma', type=float, default=0.7)
+    parser.add_argument('--dry-run', action='store_true')
+    parser.add_argument('--seed', type=int, default=42)
+    parser.add_argument('--log-interval', type=int, default=10)
+    parser.add_argument('--save-model', action='store_true')
+    args = parser.parse_args()
+
+    use_accel = torch.accelerator.is_available()
+    device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu")
+    print(f"Using device: {device}")
+
+    torch.manual_seed(args.seed)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,))
+    ])
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10('../data', train=True, download=True, transform=transform),
+        batch_size=args.batch_size, shuffle=True)
+
+    test_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10('../data', train=False, transform=transform),
+        batch_size=args.test_batch_size, shuffle=False)
+
+    model = SwinTinyNet().to(device)
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    scheduler = StepLR(optimizer, step_size=3, gamma=args.gamma)
+
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(args, model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "swin_cifar10.pt")
+
+if __name__ == "__main__":
+    main()
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: