This repository is the official implementation of Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles.
- The code is tested on Ubuntu Linux 16.04.1 with Python 3.6 and requires PyTorch, NumPy and scikit-learn.
- To install requirements:
pip install -r requirements.txt
- MNIST-M: download it from Google Drive. Extract the files and place them in ./dataset/mnist_m/.
- SVHN: download the Format 2 data (*.mat). Place the files in ./dataset/svhn/.
- USPS: download the usps.h5 file. Place the file in ./dataset/usps/.
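Before training, it can help to confirm that the dataset directories listed above are in place. The following is a minimal sketch, not part of the repository: the helper name `missing_datasets` is hypothetical, and it only checks that each directory exists and is non-empty, not that the files inside are the ones the loaders expect.

```python
from pathlib import Path

# Directory names come from the instructions above.
# The helper itself is a hypothetical convenience, not a repository script.
EXPECTED_DIRS = {
    "mnist_m": Path("dataset/mnist_m"),
    "svhn": Path("dataset/svhn"),
    "usps": Path("dataset/usps"),
}


def missing_datasets(root="."):
    """Return the names of datasets whose directory is absent or empty under root."""
    missing = []
    for name, rel_path in EXPECTED_DIRS.items():
        path = Path(root) / rel_path
        # A directory counts as present only if it exists and holds at least one entry.
        if not path.is_dir() or not any(path.iterdir()):
            missing.append(name)
    return missing


if __name__ == "__main__":
    for name in missing_datasets():
        print(f"missing or empty: {EXPECTED_DIRS[name]}")
```

Run it from the repository root; no output means all three directories were found and contain files.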
- train_model.py: train standard models via supervised learning.
- train_dann.py: train domain adaptive (DANN) models.
- eval_pipeline.py: evaluate various methods on all tasks.
- To train a standard model via supervised learning, you can use the following command:
python train_model.py --source-dataset {source dataset} --model-type {model type} --base-dir {directory to save the model}
{source dataset} can be mnist, mnist-m, svhn or usps.
{model type} can be typical_dnn or dann_arch.
- To train a domain adaptive (DANN) model, you can use the following command:
python train_dann.py --source-dataset {source dataset} --target-dataset {target dataset} --base-dir {directory to save the model} [--test-time]
{source dataset} (or {target dataset}) can be mnist, mnist-m, svhn or usps.
The optional --test-time flag replaces the target training dataset with the target test dataset.
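To train DANN models for several source-target pairs, the command above can be generated in a loop. The sketch below only builds the command lines; it uses the flags documented above and nothing else. Two things are assumptions: that source and target should differ, and the `./checkpoints` base directory, which is a placeholder. The helper name `dann_commands` is hypothetical.

```python
import shlex
from itertools import permutations

# Dataset names accepted by --source-dataset / --target-dataset, as listed above.
DATASETS = ["mnist", "mnist-m", "svhn", "usps"]


def dann_commands(base_dir, test_time=False):
    """Build one train_dann.py command per ordered (source, target) pair.

    Assumption: only pairs with source != target are of interest.
    """
    commands = []
    for source, target in permutations(DATASETS, 2):
        cmd = [
            "python", "train_dann.py",
            "--source-dataset", source,
            "--target-dataset", target,
            "--base-dir", base_dir,
        ]
        if test_time:
            # Optional flag: use the target test set in place of the target training set.
            cmd.append("--test-time")
        commands.append(cmd)
    return commands


if __name__ == "__main__":
    # "./checkpoints" is a placeholder directory, not a path the repository prescribes.
    for cmd in dann_commands("./checkpoints"):
        print(" ".join(shlex.quote(part) for part in cmd))
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` to launch the job.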
- To evaluate a method on all training-test dataset pairs, you can use the following command:
python eval_pipeline.py --model-type {model type} --method {method}
{model type} can be typical_dnn or dann_arch.
{method} can be conf_avg, ensemble_conf_avg, conf, trust_score, proxy_risk, our_ri or our_rm.
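Evaluating every method for both model types means running the command above once per combination. As a rough sketch, the combinations can be enumerated like this; the helper name `eval_commands` is hypothetical, and it uses only the option values listed above.

```python
import shlex
from itertools import product

# Values accepted by --model-type and --method, as listed above.
MODEL_TYPES = ["typical_dnn", "dann_arch"]
METHODS = [
    "conf_avg", "ensemble_conf_avg", "conf", "trust_score",
    "proxy_risk", "our_ri", "our_rm",
]


def eval_commands():
    """Build one eval_pipeline.py command per (model type, method) combination."""
    return [
        ["python", "eval_pipeline.py", "--model-type", model_type, "--method", method]
        for model_type, method in product(MODEL_TYPES, METHODS)
    ]


if __name__ == "__main__":
    # Print the commands rather than running them, so the list can be reviewed first.
    for cmd in eval_commands():
        print(" ".join(shlex.quote(part) for part in cmd))
```

This prints 14 commands (2 model types by 7 methods), which can be pasted into a shell or run one by one.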
You can run the following scripts to pre-train all models needed for the experiments.
- run_all_model_training.sh: train all supervised learning models.
- run_all_dann_training.sh: train all DANN models.
- run_all_ensemble_training.sh: train all ensemble models.
You can run the following script to get the results reported in the paper.
- run_all_evaluation.sh: evaluate all methods on all tasks.
We provide pre-trained models produced by our training scripts: run_all_model_training.sh, run_all_dann_training.sh and run_all_ensemble_training.sh.
You can download the pre-trained models from Google Drive.
Part of this code is inspired by estimating-generalization and TrustScore.
Please cite our work if you use the codebase:
@article{chen2021detecting,
title={Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles},
author={Chen, Jiefeng and Liu, Frederick and Avci, Besim and Wu, Xi and Liang, Yingyu and Jha, Somesh},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
Please refer to the LICENSE.