Simulate the Asynchronous Federated Training of NN-based Classifiers

Overview

This project provides a serial framework for simulating the asynchronous federated training of neural network (NN)-based classifiers on various standard datasets. Currently, the following datasets and models are supported:

  • MNIST with a simple CNN
  • CIFAR-10 with ResNet-18

Asynchronous Federated Learning

The framework simulates the following implementation of asynchronous federated learning, best suited for cross-silo settings (a minimal sketch of the loop follows the list of steps):

  1. The server initializes the global model and broadcasts it to all participating clients.
  2. Clients independently train the global model using their local data. Once a client completes its local training, it sends its update to the server, then requests and receives the latest version of the global model.
  3. Upon receiving the global model, clients repeat step 2.
  4. The server periodically updates the global model after receiving a predefined number of local updates (buffered asynchronous aggregation, e.g., FedBuff (https://arxiv.org/abs/2106.06639)).
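
The sketch below illustrates this loop on a toy model (a plain numpy vector standing in for the NN), purely to show the buffered bookkeeping. All names and the placeholder "local training" step are illustrative and do not mirror the identifiers used in main.py.

    import numpy as np

    rng = np.random.default_rng(0)
    num_clients, dim, Delta = 5, 4, 3

    # Step 1: the server initializes the global model (a toy vector here).
    global_model = np.zeros(dim)
    client_models = [global_model.copy() for _ in range(num_clients)]

    # Each client finishes its local training at some future time
    # (see the client update model section below for the timing assumptions).
    rates = rng.lognormal(mean=0.0, sigma=0.1, size=num_clients)
    next_finish = rng.exponential(1.0 / rates)

    buffer = []
    for _ in range(20):                                  # simulate 20 client arrivals
        i = int(np.argmin(next_finish))                  # client that finishes first
        # Step 2: stand-in for local SGD (a real client runs num_local_steps steps).
        local = client_models[i] - 0.1 * rng.normal(size=dim)
        buffer.append(local - client_models[i])          # pseudo-gradient on the model it trained
        client_models[i] = global_model.copy()           # Steps 2-3: pull the latest global model
        next_finish[i] += rng.exponential(1.0 / rates[i])
        # Step 4: buffered aggregation (FedBuff-style) after Delta local updates.
        if len(buffer) == Delta:
            global_model += np.mean(buffer, axis=0)
            buffer = []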

Supported Training Modes

The following federated training modes are supported:

  • Asynchronous modes:
    • Clients asynchronously update the server with local pseudo-gradients on the global model.
    • Clients asynchronously update the server with updates corrected using the scheme described in https://arxiv.org/abs/2405.10123 to balance heterogeneous client update frequencies.
  • Synchronous modes:
    • FedAvg (https://arxiv.org/abs/1602.05629): At each global update, the server uniformly samples a subset of clients and sends them the global model. Sampled clients synchronously update the server with their local pseudo-gradients on the global model (a sketch of one such round follows this list).
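
As a point of comparison with the asynchronous loop above, the following sketch shows one synchronous FedAvg round on the same kind of toy model. The uniform sampling of Delta clients and the averaging of their pseudo-gradients follow the description above; the local SGD step is a placeholder and the names are illustrative, not taken from main.py.

    import numpy as np

    rng = np.random.default_rng(1)
    num_clients, dim, Delta, lr = 10, 4, 3, 0.01

    global_model = np.zeros(dim)
    sampled = rng.choice(num_clients, size=Delta, replace=False)   # uniform sampling of Delta clients

    pseudo_grads = []
    for _ in sampled:
        local = global_model.copy()
        for _ in range(5):                              # stand-in for num_local_steps SGD steps
            local -= lr * rng.normal(size=dim)          # placeholder stochastic gradient
        pseudo_grads.append(local - global_model)       # pseudo-gradient on the global model

    global_model += np.mean(pseudo_grads, axis=0)       # server averages the sampled updates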

Client update model

The interval $t_i$ between consecutive updates from client $i \in \{1,...,n\}$ is modeled as an exponential random variable $t_i \sim \text{Exp}(\lambda_i)$. Given a user-specified standard deviation parameter $\sigma>0$, client rates are generated as samples from a log-normal distribution with location parameter $\mu=0$, i.e., $\lambda_i \sim \text{Log-normal}(0, \sigma^2)$.
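
A minimal sketch of this timing model using numpy's random generator; here sigma stands in for the --client_rate_std argument described below, and the variable names are illustrative.

    import numpy as np

    rng = np.random.default_rng(42)
    num_clients, sigma = 10, 0.1        # sigma plays the role of --client_rate_std

    rates = rng.lognormal(mean=0.0, sigma=sigma, size=num_clients)  # lambda_i ~ Log-normal(0, sigma^2)
    intervals = rng.exponential(scale=1.0 / rates)                  # t_i ~ Exp(lambda_i), mean 1 / lambda_i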

Usage

To run the main.py script, use the following command format (an example invocation is given after the argument list):

python main.py [--<argument> <value> ...]

Arguments

  • --num_clients: Specifies the number of clients participating in federated training. (Type: int, Default: 10)
  • --dataset: Indicates the dataset to be used for training. Options are MNIST or CIFAR-10. (Type: str, Default: mnist)
  • --train_batch_size: Sets the batch size for local training at each client. (Type: int, Default: 64)
  • --test_batch_size: Defines the batch size for evaluating loss and accuracy on the test data. (Type: int, Default: 32)
  • --Delta: Determines the number of local updates required for a global aggregation, or the number of (uniformly) sampled clients for FedAvg. (Type: int, Default: 3)
  • --lr: Specifies the learning rate for local training. (Type: float, Default: 0.01)
  • --num_local_steps: Sets the number of local stochastic gradient descent (SGD) steps for training at each client. (Type: int, Default: 100)
  • --dirichlet_alpha: Controls the heterogeneity among client datasets using a Dirichlet distribution sample. Smaller values yield more heterogeneous datasets. (Type: float, Default: 1.0)
  • --mode: Chooses the communication mode for training. Options are sync for FedAvg or async for asynchronous training. (Type: str, Default: async)
  • --correction: Enables the correction scheme described in https://arxiv.org/abs/2405.10123 to balance heterogeneous client update rates. Set to True to activate. (Type: bool, Default: False)
  • --client_rate_std: Specifies the standard deviation used for generating client update rates. (Type: float, Default: 0.1)
  • --T_train: Sets the total training time in time units. (Type: float, Default: 5.0)
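
For example, a possible invocation (values chosen purely for illustration; the exact strings accepted by options such as --dataset may differ from what is shown here) is:

python main.py --num_clients 20 --dataset mnist --mode async --Delta 5 --lr 0.01 --T_train 10.0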
