This project provides a serial framework for simulating the asynchronous federated training of neural network (NN)-based classifiers on various standard datasets. Currently, the following datasets and models are supported:
- MNIST with a simple CNN
- CIFAR-10 with ResNet-18
The framework simulates the following implementation of asynchronous federated learning, best suited for cross-silo settings:
1. The server initializes the global model and broadcasts it to all participating clients.
2. Clients independently train the global model using their local data. Once a client completes its training, it sends its update to the server, then requests and receives the latest version of the global model.
3. Upon receiving the global model, the client repeats step 2.
4. The server periodically updates the global model after receiving a predefined number of local updates (buffered asynchronous aggregation, as in FedBuff, https://arxiv.org/abs/2106.06639).
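
For concreteness, here is a minimal sketch of what the buffered aggregation step could look like, assuming the global model and client updates are stored as dicts of tensors; `aggregate_buffered` and `server_lr` are illustrative names, not this repository's actual API:

```python
import torch

def aggregate_buffered(global_params, buffer, server_lr=1.0):
    """Apply the mean of the buffered client updates to the global model.

    global_params: dict mapping parameter names to tensors (the global model).
    buffer: list of pseudo-gradient dicts, one per client update, where each
        entry is (locally trained parameters) - (global parameters received).
    server_lr: hypothetical server-side step size.
    """
    with torch.no_grad():
        for name, param in global_params.items():
            # Average the buffered updates for this parameter tensor.
            mean_delta = torch.stack([delta[name] for delta in buffer]).mean(dim=0)
            param.add_(server_lr * mean_delta)  # in-place global update
    buffer.clear()  # start collecting the next batch of local updates
```

In this reading, the server calls `aggregate_buffered` whenever the buffer reaches the predefined size (the `--Delta` argument below).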
The following federated training modes are supported:
- Asynchronous modes:
- Clients asynchronously update the server with local pseudo-gradients on the global model.
- Clients asynchronously update the server with updates corrected using the scheme described in https://arxiv.org/abs/2405.10123 to balance heterogeneous client update frequencies.
- Synchronous modes:
- FedAvg (https://arxiv.org/abs/1602.05629): At each global update, the server uniformly samples a subset of clients and sends them the global model. Sampled clients synchronously update the server with their local pseudo-gradients on the global model.
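
All of these modes rely on the same client-side primitive: run local SGD from the current global model and send back the resulting pseudo-gradient. A minimal sketch, with the function name, signature, and dict-of-tensors representation as illustrative assumptions:

```python
import torch

def local_update(model, global_params, data_loader, lr=0.01, num_local_steps=100):
    """Train locally from the global model and return the pseudo-gradient,
    defined here as (parameters after local training) - (starting global parameters)."""
    model.load_state_dict(global_params)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    data_iter = iter(data_loader)
    for _ in range(num_local_steps):
        try:
            inputs, targets = next(data_iter)
        except StopIteration:  # restart the local data stream when exhausted
            data_iter = iter(data_loader)
            inputs, targets = next(data_iter)
        optimizer.zero_grad()
        loss_fn(model(inputs), targets).backward()
        optimizer.step()
    return {name: param.detach().clone() - global_params[name]
            for name, param in model.state_dict().items()}
```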
To run the `main.py` script, use the following command format:

```bash
python main.py --args <args>
```
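
For example, an asynchronous run on CIFAR-10 with 20 clients and the correction scheme enabled might look like this (flag values follow the options documented below):

```bash
python main.py --num_clients 20 --dataset CIFAR-10 --mode async --correction True --dirichlet_alpha 0.5
```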
The script accepts the following command-line arguments:

- `--num_clients`: Number of clients participating in federated training. (Type: `int`, Default: `10`)
- `--dataset`: Dataset to be used for training. Options are `MNIST` or `CIFAR-10`. (Type: `str`, Default: `mnist`)
- `--train_batch_size`: Batch size for local training at each client. (Type: `int`, Default: `64`)
- `--test_batch_size`: Batch size for evaluating loss and accuracy on the test data. (Type: `int`, Default: `32`)
- `--Delta`: Number of local updates required for a global aggregation, or the number of (uniformly) sampled clients for FedAvg. (Type: `int`, Default: `3`)
- `--lr`: Learning rate for local training. (Type: `float`, Default: `0.01`)
- `--num_local_steps`: Number of local stochastic gradient descent (SGD) steps for training at each client. (Type: `int`, Default: `100`)
- `--dirichlet_alpha`: Controls the heterogeneity among client datasets using a Dirichlet distribution sample; smaller values yield more heterogeneous datasets (see the partitioning sketch after this list). (Type: `float`, Default: `1.0`)
- `--mode`: Communication mode for training. Options are `sync` for FedAvg or `async` for asynchronous training. (Type: `str`, Default: `async`)
- `--correction`: Enables the correction scheme described in https://arxiv.org/abs/2405.10123 to balance heterogeneous client update rates. Set to `True` to activate. (Type: `bool`, Default: `False`)
- `--client_rate_std`: Standard deviation used for generating client update rates (see the scheduling sketch after this list). (Type: `float`, Default: `0.1`)
- `--T_train`: Total training time in simulation time units. (Type: `float`, Default: `5.0`)
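
To make the effect of `--dirichlet_alpha` concrete, here is a sketch of the standard Dirichlet label-skew partitioning used widely in the federated learning literature; the repository's exact partitioning code may differ:

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split dataset indices across clients with Dirichlet(alpha) label skew.

    For each class, draw a proportion vector over clients from a symmetric
    Dirichlet(alpha) and split that class's indices accordingly.
    Smaller alpha -> more skewed (heterogeneous) client datasets.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(part.tolist())
    return client_indices
```

With the default `alpha = 1.0`, per-client class proportions are moderately uneven; as `alpha` approaches zero, each client's data concentrates on just a few classes.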
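
Likewise, `--client_rate_std` and `--T_train` suggest how a serial simulator can order asynchronous events: draw one update rate per client, then replay update completions in time order until the training horizon. The details below (rates drawn from a normal distribution around 1.0 and clipped positive, exponential inter-update delays) are assumptions for illustration, not necessarily the repository's implementation:

```python
import heapq
import numpy as np

def simulate_async_schedule(num_clients=10, client_rate_std=0.1, T_train=5.0, seed=0):
    """Serially replay asynchronous client update arrivals with a priority queue."""
    rng = np.random.default_rng(seed)
    # One update rate per client, clipped so every rate stays positive.
    rates = np.clip(rng.normal(loc=1.0, scale=client_rate_std, size=num_clients),
                    0.05, None)
    # Each event is (completion_time, client_id) for one local training round.
    events = [(rng.exponential(1.0 / rates[cid]), cid) for cid in range(num_clients)]
    heapq.heapify(events)
    while events:
        t, cid = heapq.heappop(events)
        if t > T_train:  # training horizon reached
            break
        # ... client `cid` would deliver its pseudo-gradient to the server here ...
        heapq.heappush(events, (t + rng.exponential(1.0 / rates[cid]), cid))
```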