
vfdev-5/FixMatch-pytorch


FixMatch experiments in PyTorch and Ignite

Experiments with "FixMatch" on Cifar10 dataset.

Based on "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence" and its official code.
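
For orientation, the core of FixMatch's unlabeled loss can be sketched in plain Python. This is a minimal illustration written for this README, not code from this repository or from the official implementation. The model's prediction on a weakly augmented image gives a pseudo-label. That pseudo-label is kept only if its confidence reaches a threshold (0.95 in the paper, used here as the assumed default). The model is then trained to predict the pseudo-label on a strongly augmented view of the same image. The function names are hypothetical. In a real training loop, the weak-view prediction is treated as a fixed target (no gradient), and this term is added, with a weight, to the supervised loss.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_total = math.log(sum(math.exp(z - m) for z in logits))
    return [z - m - log_total for z in logits]


def fixmatch_unlabeled_loss(
    weak_logits: Sequence[Sequence[float]],
    strong_logits: Sequence[Sequence[float]],
    threshold: float = 0.95,
) -> float:
    """Masked cross-entropy averaged over the whole unlabeled batch (hypothetical helper)."""
    total = 0.0
    for weak, strong in zip(weak_logits, strong_logits):
        # Pseudo-label and its confidence come from the weakly augmented view.
        weak_probs = [math.exp(v) for v in log_softmax(weak)]
        confidence = max(weak_probs)
        pseudo_label = weak_probs.index(confidence)
        if confidence >= threshold:
            # Train the strongly augmented view to predict the pseudo-label.
            total += -log_softmax(strong)[pseudo_label]
    # Low-confidence samples contribute zero but still count in the denominator.
    return total / len(weak_logits)


if __name__ == "__main__":
    weak = [[10.0, 0.0, 0.0], [0.0, 0.0, 0.0]]    # confident vs. uniform prediction
    strong = [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
    print(fixmatch_unlabeled_loss(weak, strong))  # only the first sample passes the threshold
```

In this example, only the first sample is confident enough. Its strong view is uniform over 3 classes, so the printed loss is ln(3) / 2, about 0.549.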

The data augmentation policy is CTA (CTAugment).
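
CTAugment, introduced with ReMixMatch, learns which augmentation magnitudes are safe to apply. It does this by tracking a weight for each magnitude bin of each transformation. Below is a simplified sketch that loosely follows that idea. It is not this repository's implementation.

- Each weight is an exponential moving average of an agreement score. The score measures how well the model's prediction on an augmented labeled image matches its true label.
- Bins whose max-normalized weight falls below a threshold are never sampled.

The class and function names are hypothetical. The exact form of the agreement score is also an assumption, as are the default decay of 0.99 and threshold of 0.8.

```python
import random
from typing import List


class MagnitudeBins:
    """Hypothetical CTAugment-style weights for the magnitude bins of one transformation."""

    def __init__(self, num_bins: int, decay: float = 0.99, threshold: float = 0.8):
        self.weights = [1.0] * num_bins  # every magnitude starts out as "safe"
        self.decay = decay               # assumed value
        self.threshold = threshold       # assumed value

    def update(self, bin_index: int, score: float) -> None:
        # Exponential moving average of the agreement score for this bin.
        old = self.weights[bin_index]
        self.weights[bin_index] = self.decay * old + (1.0 - self.decay) * score

    def sampling_probs(self) -> List[float]:
        # Normalize by the max (so the best bin always survives), drop bins
        # below the threshold, and renormalize into a distribution.
        top = max(self.weights)
        kept = [w / top if w / top >= self.threshold else 0.0 for w in self.weights]
        total = sum(kept)
        return [k / total for k in kept]

    def sample(self, rng: random.Random) -> int:
        # Zero-probability bins are never drawn by random.choices.
        probs = self.sampling_probs()
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def agreement(probs: List[float], label: int) -> float:
    """1 - half the L1 distance between a probability vector and the one-hot label, in [0, 1]."""
    l1 = sum(abs(p - (1.0 if i == label else 0.0)) for i, p in enumerate(probs))
    return 1.0 - 0.5 * l1
```

During training, a loop would sample a bin for each transformation to augment unlabeled images. It would also call `update` with the agreement score measured on augmented labeled images.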

Online logging on W&B: https://app.wandb.ai/vfdev-5/fixmatch-pytorch

Requirements

pip install --upgrade --pre hydra-core tensorboardX
pip install --upgrade git+https://github.com/pytorch/ignite
# pip install --upgrade --pre pytorch-ignite

Optionally, we can install wandb for online experiment tracking:

pip install wandb

We can also replace Pillow with Pillow-SIMD to accelerate image processing:

pip uninstall -y pillow && CC="cc -mavx2" pip install --no-cache-dir --force-reinstall pillow-simd

Training

python -u main_fixmatch.py model=WRN-28-2
  • Default output folder: "/tmp/output-fixmatch-cifar10".
  • For the complete list of options: python -u main_fixmatch.py --help

This script automatically trains on all available GPUs (torch.nn.DataParallel).

To specify the input and output folders:

python -u main_fixmatch.py dataflow.data_path=/data/cifar10/ hydra.run.dir=/output-fixmatch model=WRN-28-2
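
Options such as dataflow.data_path or hydra.run.dir are Hydra dotted overrides, where each dot descends one level into the nested configuration. The helper below is a hypothetical, much-simplified illustration of that mapping. It is not part of Hydra's API, and Hydra's real override grammar and OmegaConf type handling are richer.

```python
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    # Tiny subset of type inference: booleans, ints, floats, otherwise a string.
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply "a.b.c=value" strings to a nested dict (hypothetical helper)."""
    for item in overrides:
        key, _, raw = item.partition("=")  # split on the first "=" only
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels as needed
        node[leaf] = parse_value(raw)
    return config


if __name__ == "__main__":
    print(apply_overrides({}, ["dataflow.data_path=/data/cifar10/", "hydra.run.dir=/output-fixmatch", "model=WRN-28-2"]))
```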

To use the wandb logger, log in first and then run with online_exp_tracking.wandb=true:

wandb login <token>
python -u main_fixmatch.py model=WRN-28-2 online_exp_tracking.wandb=true

To see other options:

python -u main_fixmatch.py --help

Training curves visualization

By default, training curves are logged to TensorBoard:

tensorboard --logdir=/tmp/output-fixmatch-cifar10/

Distributed Data Parallel (DDP) on multiple GPUs (Experimental)

For example, training on 2 GPUs:

python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=WRN-28-2 distributed.backend=nccl
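
torch.distributed.launch starts one process per GPU, so --nproc_per_node=2 gives two processes. Each process learns its identity from environment variables such as RANK and WORLD_SIZE, and, as assumed here, LOCAL_RANK. The sketch below shows how a process could read them under that assumption. read_dist_info and DistInfo are hypothetical names. The fallbacks describe a plain single-process run.

```python
import os
from typing import Mapping, NamedTuple


class DistInfo(NamedTuple):
    rank: int        # global index of this process
    local_rank: int  # index of this process on its node (typically selects the GPU)
    world_size: int  # total number of processes


def read_dist_info(env: Mapping[str, str] = os.environ) -> DistInfo:
    # Fall back to a single-process setup when the launcher did not set the variables.
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    local_rank = int(env.get("LOCAL_RANK", str(rank)))
    return DistInfo(rank, local_rank, world_size)


if __name__ == "__main__":
    print(read_dist_info())
```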

TPU(s) on Colab (Experimental)

For example, training on 8 TPU cores in distributed mode:

python -u main_fixmatch.py model=resnet18 distributed.backend=xla-tpu distributed.nproc_per_node=8
# or python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8

Experiments

Faster ResNet-18 training

  • reduced the number of epochs
  • reduced the number of CTA updates
  • reduced the EMA decay
  • set the confidence threshold to 0.8
python main_fixmatch.py distributed.backend=nccl online_exp_tracking.wandb=true solver.num_epochs=500 \
    ssl.confidence_threshold=0.8 ema_decay=0.9 ssl.cta_update_every=15
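
The EMA decay controls how quickly the exponentially averaged copy of the model weights follows the trained weights. FixMatch typically evaluates with this averaged copy. A common form of the update is ema <- d * ema + (1 - d) * current. With this form, weights from roughly 1 / (1 - d) steps back still carry noticeable weight: about 10 steps for d = 0.9 versus about 1000 for d = 0.999. The plain-Python sketch below illustrates this assumed update rule, with parameters stored as lists of floats. It is not the repository's code, and ema_update is a hypothetical name.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def ema_update(ema: Params, current: Params, decay: float) -> None:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * current (hypothetical helper)."""
    for name, values in current.items():
        ema[name] = [decay * e + (1.0 - decay) * v for e, v in zip(ema[name], values)]


if __name__ == "__main__":
    # The trained weight jumps from 0 to 1; see how far each EMA follows in 10 steps.
    for decay in (0.9, 0.999):
        ema = {"w": [0.0]}
        for _ in range(10):
            ema_update(ema, {"w": [1.0]}, decay)
        print(decay, round(ema["w"][0], 4))
```

After 10 steps, the EMA with d = 0.9 has covered about 65% of the jump (1 - 0.9^10). The EMA with d = 0.999 has covered only about 1%. A smaller decay therefore lets the averaged model track a shorter training run.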

About

Implementation of FixMatch in PyTorch and experiments
