diff --git a/imagenet/main.py b/imagenet/main.py
index cc32d50733..a7d01139da 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -1,16 +1,22 @@
 import argparse
+import datetime
+import logging
 import os
 import random
 import shutil
 import time
 import warnings
 from enum import Enum
+from typing import get_type_hints, Tuple, List, Union, Dict, NamedTuple
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.nn.parallel
 import torch.optim
 import torch.utils.data
@@ -18,12 +24,132 @@
 import torchvision.datasets as datasets
 import torchvision.models as models
 import torchvision.transforms as transforms
+from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score, confusion_matrix
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import Subset
+from torch.utils.tensorboard import SummaryWriter
+from torchmetrics.classification import MulticlassPrecisionRecallCurve
 
 model_names = sorted(name for name in models.__dict__
-    if name.islower() and not name.startswith("__")
-    and callable(models.__dict__[name]))
+                     if name.islower() and not name.startswith("__")
+                     and callable(models.__dict__[name]))
+
+
+# ================
+# Model evaluation
+# ================
+
+class TrainMetrics(NamedTuple):
+    class_labels: List[int]
+    acc_balanced: float
+    f1_micro: float
+    f1_macro: float
+    prec_micro: float
+    prec_macro: float
+    rec_micro: float
+    rec_macro: float
+
+
+class ValidationMetrics(NamedTuple):
+    class_labels: List[int]
+    acc_balanced: float
+    f1_micro: float
+    f1_macro: float
+    prec_micro: float
+    prec_macro: float
+    rec_micro: float
+    rec_macro: float
+    f1_per_class: List[Tuple[int, float]]
+    conf_matrix: np.ndarray
+    labels_true: np.ndarray
+    labels_pred: np.ndarray
+    labels_probs: torch.Tensor
+    fig_pr_curve_micro: plt.Figure
+
+
+class EarlyStopping:
+    """
+    Based on:
+    - https://pytorch.org/ignite/_modules/ignite/handlers/early_stopping.html#EarlyStopping
+    - https://github.com/Bjarten/early-stopping-pytorch
+    """
+
+    def __init__(self, patience: int = 3, min_delta: float = 1, min_epochs: int = 50):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.min_validation_loss = float('inf')
+        self.epoch_min_validation_loss = 0
+        self.should_stop = False
+        self.min_epochs = min_epochs
+
+    def __call__(self, validation_loss, epoch):
+        if validation_loss < self.min_validation_loss:
+            self.min_validation_loss = validation_loss
+            self.epoch_min_validation_loss = epoch
+            self.counter = 0
+        elif validation_loss > (self.min_validation_loss + self.min_delta):
+            self.counter += 1
+            if self.counter >= self.patience and epoch >= self.min_epochs:
+                self.should_stop = True
+
+
+# ================
+# Custom model & data
+# ================
+def safe_import(module_name):
+    import importlib
+    import sys
+
+    if module_name in sys.modules:
+        return sys.modules[module_name]
+
+    try:
+        module = importlib.import_module(module_name)
+        return module
+    except ImportError as e:
+        print(f'Error importing module {module_name}. '
+              f'Make sure the module exists and can be imported.')
+        raise e
+
+
+def get_module_method(module, method_name, expected_type_hint):
+    if hasattr(module, method_name) and callable(getattr(module, method_name)):
+        method = getattr(module, method_name)
+        if get_type_hints(method).get('return') != expected_type_hint:
+            raise Exception(
+                f'The provided method {module}.{method_name} does not respect the '
+                f'expected type hint {expected_type_hint}')
+        return method()
+    else:
+        raise Exception(f'The provided module {module} does not have method {method_name}')
+
+
+def get_run_name(model, train_dataset, val_dataset, args):
+    today = datetime.datetime.now().strftime('%m%d-%H%M')
+
+    model_info = model.__class__.__name__
+    dataset_info = train_dataset.__class__.__name__
+    if args.use_module_definitions:
+        module = safe_import(args.use_module_definitions.replace('.py', ''))
+        try:
+            model_info = get_module_method(module, 'get_model_info', str)
+        except Exception:
+            pass
+        try:
+            dataset_info = get_module_method(module, 'get_dataset_info', str)
+        except Exception:
+            pass
+
+    train_dataset_size = len(train_dataset)
+    val_dataset_size = len(val_dataset)
+
+    return (f"{today}_{model_info}"
+            f"_{dataset_info}-train-{train_dataset_size}-val-{val_dataset_size}")
+
+
+# ================
+# Arguments
+# ================
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
                     help='path to dataset (default: imagenet)')
@@ -31,8 +157,18 @@
 parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                     choices=model_names,
                     help='model architecture: ' +
-                         ' | '.join(model_names) +
-                         ' (default: resnet18)')
+                    ' | '.join(model_names) +
+                    ' (default: resnet18)')
+parser.add_argument('-m', '--use-module-definitions', metavar='MODULE', default=None,
+                    help='load a custom py file for the model and/or dataset & loader. '
+                         'The file can contain the following functions: '
+                         'get_model() -> nn.Module; '
+                         'get_train_dataset() -> torch.utils.data.Dataset; '
+                         'get_val_dataset() -> torch.utils.data.Dataset; '
+                         'get_train_loader() -> torch.utils.data.DataLoader; '
+                         'get_val_loader() -> torch.utils.data.DataLoader '
+                         '(default: None)')
+parser.add_argument('-tb', '--tb-summary-writer-dir', metavar='SUMMARY_DIR', default=None,
+                    help='directory in which a per-run TensorBoard SummaryWriter is created (default: None)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=90, type=int, metavar='N',
@@ -55,6 +191,8 @@
                     metavar='N', help='print frequency (default: 10)')
 parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
+parser.add_argument('-ch', '--checkpoints', default='', type=str,
+                    help='path to checkpoints dir (default: none)')
 parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                     help='evaluate model on validation set')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
@@ -78,7 +216,9 @@
                          'multi node data parallel training')
 parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
 
+log = print
 best_acc1 = 0
+best_metrics = ValidationMetrics([], 0, 0, 0, 0, 0, 0, 0, [], [], [], [], [], None)
 
 
 def main():
@@ -107,7 +247,8 @@ def main():
     if torch.cuda.is_available():
         ngpus_per_node = torch.cuda.device_count()
         if ngpus_per_node == 1 and args.dist_backend == "nccl":
-            warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
+            warnings.warn(
+                "nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
     else:
         ngpus_per_node = 1
 
@@ -124,11 +265,19 @@ def main():
 
 
 def main_worker(gpu, ngpus_per_node, args):
-    global best_acc1
+    global best_acc1, best_metrics, log
     args.gpu = gpu
 
+    if args.use_module_definitions:
+        module = safe_import(args.use_module_definitions.replace('.py', ''))
+        try:
+            logger = get_module_method(module, 'get_logger', logging.Logger)
+            log = logger.info
+        except Exception:
+            pass
+
     if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
+        log("Use GPU: {} for training".format(args.gpu))
 
     if args.distributed:
         if args.dist_url == "env://" and args.rank == -1:
@@ -141,14 +290,18 @@ def main_worker(gpu, ngpus_per_node, args):
                                 world_size=args.world_size, rank=args.rank)
     # create model
     if args.pretrained:
-        print("=> using pre-trained model '{}'".format(args.arch))
+        log("=> using pre-trained model '{}'".format(args.arch))
         model = models.__dict__[args.arch](pretrained=True)
     else:
-        print("=> creating model '{}'".format(args.arch))
-        model = models.__dict__[args.arch]()
+        if not args.use_module_definitions:
+            log("=> creating model '{}'".format(args.arch))
+            model = models.__dict__[args.arch]()
+        else:
+            module = safe_import(args.use_module_definitions.replace('.py', ''))
+            model = get_module_method(module, 'get_model', nn.Module)
 
     if not torch.cuda.is_available() and not torch.backends.mps.is_available():
-        print('using CPU, this will be slow')
+        log('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
@@ -197,14 +350,14 @@ def main_worker(gpu, ngpus_per_node, args):
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
-    
+
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
     scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
-    
+
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
+            log("=> loading checkpoint '{}'".format(args.resume))
             if args.gpu is None:
                 checkpoint = torch.load(args.resume)
             elif torch.cuda.is_available():
@@ -219,40 +372,44 @@ def main_worker(gpu, ngpus_per_node, args):
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
             scheduler.load_state_dict(checkpoint['scheduler'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
+            log("=> loaded checkpoint '{}' (epoch {})"
+                .format(args.resume, checkpoint['epoch']))
         else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
-
+            log("=> no checkpoint found at '{}'".format(args.resume))
 
     # Data loading code
     if args.dummy:
-        print("=> Dummy data is used!")
+        log("=> Dummy data is used!")
         train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor())
         val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor())
     else:
-        traindir = os.path.join(args.data, 'train')
-        valdir = os.path.join(args.data, 'val')
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                         std=[0.229, 0.224, 0.225])
-
-        train_dataset = datasets.ImageFolder(
-            traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-
-        val_dataset = datasets.ImageFolder(
-            valdir,
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+        if not args.use_module_definitions:
+            traindir = os.path.join(args.data, 'train')
+            valdir = os.path.join(args.data, 'val')
+            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                             std=[0.229, 0.224, 0.225])
+
+            train_dataset = datasets.ImageFolder(
+                traindir,
+                transforms.Compose([
+                    transforms.RandomResizedCrop(224),
+                    transforms.RandomHorizontalFlip(),
+                    transforms.ToTensor(),
+                    normalize,
+                ]))
+
+            val_dataset = datasets.ImageFolder(
+                valdir,
+                transforms.Compose([
+                    transforms.Resize(256),
+                    transforms.CenterCrop(224),
+                    transforms.ToTensor(),
+                    normalize,
+                ]))
+        else:
+            module = safe_import(args.use_module_definitions.replace('.py', ''))
+            train_dataset = get_module_method(module, 'get_train_dataset', torch.utils.data.Dataset)
+            val_dataset = get_module_method(module, 'get_val_dataset', torch.utils.data.Dataset)
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
@@ -261,60 +418,178 @@ def main_worker(gpu, ngpus_per_node, args):
         train_sampler = None
         val_sampler = None
 
-    train_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
-        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
+    if not args.use_module_definitions:
+        train_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
+            num_workers=args.workers, pin_memory=True, sampler=train_sampler)
 
-    val_loader = torch.utils.data.DataLoader(
-        val_dataset, batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True, sampler=val_sampler)
+        val_loader = torch.utils.data.DataLoader(
+            val_dataset, batch_size=args.batch_size, shuffle=False,
+            num_workers=args.workers, pin_memory=True, sampler=val_sampler)
+    else:
+        module = safe_import(args.use_module_definitions.replace('.py', ''))
+        train_loader = get_module_method(module, 'get_train_loader', torch.utils.data.DataLoader)
+        val_loader = get_module_method(module, 'get_val_loader', torch.utils.data.DataLoader)
+
+    target_class_translations = None
+    if args.use_module_definitions:
+        try:
+            module = safe_import(args.use_module_definitions.replace('.py', ''))
+            target_class_translations = get_module_method(module, 'target_class_translations', Dict[int, str])
+            log(f'Loaded target_class_translations from {args.use_module_definitions}')
+        except Exception as e:
+            log(f'Error getting target_class_translations from {args.use_module_definitions}: {e}')
+
+    def get_target_class(cl: int) -> str:
+        if target_class_translations:
+            return target_class_translations[cl]
+        return f"Class-{cl}"
 
     if args.evaluate:
         validate(val_loader, model, criterion, args)
        return
 
-    for epoch in range(args.start_epoch, args.epochs):
-        if args.distributed:
-            train_sampler.set_epoch(epoch)
-
-        # train for one epoch
-        train(train_loader, model, criterion, optimizer, epoch, device, args)
-
-        # evaluate on validation set
-        acc1 = validate(val_loader, model, criterion, args)
-
-        scheduler.step()
-
-        # remember best acc@1 and save checkpoint
-        is_best = acc1 > best_acc1
-        best_acc1 = max(acc1, best_acc1)
-
-        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
-                and args.rank % ngpus_per_node == 0):
-            save_checkpoint({
-                'epoch': epoch + 1,
-                'arch': args.arch,
-                'state_dict': model.state_dict(),
-                'best_acc1': best_acc1,
-                'optimizer' : optimizer.state_dict(),
-                'scheduler' : scheduler.state_dict()
-            }, is_best)
-
-
-def train(train_loader, model, criterion, optimizer, epoch, device, args):
+    run_name = get_run_name(model, train_dataset, val_dataset, args)
+    tensorboard_writer = None
+    if args.tb_summary_writer_dir:
+        tb_log_dir_path = os.path.join(args.tb_summary_writer_dir, run_name)
+        tensorboard_writer = SummaryWriter(tb_log_dir_path)
+        log(f'TensorBoard summary writer is created at {tb_log_dir_path}')
+
+        try:
+            model.eval()
+            with torch.no_grad():
+                images, _ = next(iter(train_loader))
+                if args.gpu is not None and torch.cuda.is_available():
+                    images = images.cuda(args.gpu, non_blocking=True)
+                if torch.backends.mps.is_available():
+                    images = images.to('mps')
+                tensorboard_writer.add_graph(model, images)
+        except Exception as e:
+            log(f'Failed to add graph to TensorBoard: {e}')
+
+    early_stopping = EarlyStopping(patience=5, min_delta=0.5, min_epochs=50)
+    try:
+        for epoch in range(args.start_epoch, args.epochs):
+            if args.distributed:
+                train_sampler.set_epoch(epoch)
+
+            # train for one epoch
+            train_acc1, train_loss, train_metrics = train(train_loader, model, criterion, optimizer,
+                                                          epoch, device, args)
+
+            # evaluate on validation set
+            val_acc1, val_loss, val_metrics = validate(val_loader, model, criterion, args)
+            scheduler.step()
+            early_stopping(val_loss, epoch)
+
+            # remember best acc@1 and save checkpoint
+            is_best = val_acc1 > best_acc1
+            best_acc1 = max(val_acc1, best_acc1)
+            best_metrics = val_metrics if val_metrics.f1_micro > best_metrics.f1_micro else best_metrics
+
+            if not args.multiprocessing_distributed or \
+                    (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0) or \
+                    epoch == args.epochs - 1:
+                save_checkpoint({
+                    'epoch': epoch + 1,
+                    'arch': args.arch,
+                    'state_dict': model.state_dict(),
+                    'best_acc1': best_acc1,
+                    'optimizer': optimizer.state_dict(),
+                    'scheduler': scheduler.state_dict()
+                }, is_best, run_name, args.checkpoints)
+
+            if tensorboard_writer:
+                tensorboard_writer.add_scalars('Loss', dict(train=train_loss, val=val_loss), epoch + 1)
+                tensorboard_writer.add_scalars('Metrics/Accuracy',
+                                               dict(val_acc=val_acc1 / 100.0,
+                                                    val_bacc=val_metrics.acc_balanced,
+                                                    train_acc=train_acc1 / 100.0,
+                                                    train_bacc=train_metrics.acc_balanced),
+                                               epoch + 1)
+                tensorboard_writer.add_scalars('Metrics/F1',
+                                               dict(val_micro=val_metrics.f1_micro,
+                                                    val_macro=val_metrics.f1_macro,
+                                                    train_micro=train_metrics.f1_micro,
+                                                    train_macro=train_metrics.f1_macro),
+                                               epoch + 1)
+                tensorboard_writer.add_scalars('Metrics/Precision',
+                                               dict(val_micro=val_metrics.prec_micro,
+                                                    val_macro=val_metrics.prec_macro,
+                                                    train_micro=train_metrics.prec_micro,
+                                                    train_macro=train_metrics.prec_macro),
+                                               epoch + 1)
+                tensorboard_writer.add_scalars('Metrics/Recall',
+                                               dict(val_micro=val_metrics.rec_micro,
+                                                    val_macro=val_metrics.rec_macro,
+                                                    train_micro=train_metrics.rec_micro,
+                                                    train_macro=train_metrics.rec_macro),
+                                               epoch + 1)
+                tensorboard_writer.add_scalars('Metrics/F1/class',
+                                               {get_target_class(cl): f1 for cl, f1 in val_metrics.f1_per_class},
+                                               epoch + 1)
+
+                if epoch < 10 or epoch % 5 == 0 or epoch == args.epochs - 1:
+                    # class_labels is already unique and sorted (see calculate_validation_metrics),
+                    # so the names line up with the confusion-matrix rows
+                    class_names = [get_target_class(cl) for cl in val_metrics.class_labels]
+                    fig_abs, _ = plot_confusion_matrix(val_metrics.conf_matrix, class_names=class_names,
+                                                       normalize=False)
+                    fig_rel, _ = plot_confusion_matrix(val_metrics.conf_matrix, class_names=class_names,
+                                                       normalize=True)
+                    tensorboard_writer.add_figure('Confusion matrix', fig_abs, epoch + 1)
+                    tensorboard_writer.add_figure('Confusion matrix/normalized', fig_rel, epoch + 1)
+
+                    for cl in val_metrics.class_labels:
+                        class_index = int(cl)
+                        labels_true = val_metrics.labels_true == class_index
+                        pred_probs = val_metrics.labels_probs[:, class_index]
+                        tensorboard_writer.add_pr_curve(f'PR curve/{get_target_class(class_index)}',
+                                                        labels_true, pred_probs, epoch + 1)
+
+                    tensorboard_writer.add_figure('PR curve', val_metrics.fig_pr_curve_micro, epoch + 1)
+
+            if early_stopping.should_stop:
+                log(f"Early stopping at epoch {epoch + 1}")
+                break
+
+    except KeyboardInterrupt:
+        log('Training interrupted, saving hparams to TensorBoard...')
+    finally:
+        if args.use_module_definitions:
+            module = safe_import(args.use_module_definitions.replace('.py', ''))
+            hparams = None
+            try:
+                hparams = get_module_method(module, 'get_hparams', Dict[str, Union[int, float, bool, str]])
+            except Exception as e:
+                log(f'Error getting get_hparams from {args.use_module_definitions}: {e}')
+            if tensorboard_writer and hparams:
+                tensorboard_writer.add_hparams(hparams, {
+                    'hparams/Accuracy': best_acc1 / 100.0,
+                    'hparams/F1-micro': best_metrics.f1_micro,
+                    'hparams/F1-macro': best_metrics.f1_macro,
+                    'hparams/P-micro': best_metrics.prec_micro,
+                    'hparams/P-macro': best_metrics.prec_macro,
+                    'hparams/R-micro': best_metrics.rec_micro,
+                    'hparams/R-macro': best_metrics.rec_macro,
+                })
+
+
+def train(train_loader, model, criterion, optimizer, epoch, device, args) -> Tuple[float, float, TrainMetrics]:
     batch_time = AverageMeter('Time', ':6.3f')
     data_time = AverageMeter('Data', ':6.3f')
     losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+    acc_top1 = AverageMeter('Acc@1', ':6.2f')
+    acc_top5 = AverageMeter('Acc@5', ':6.2f')
+
     progress = ProgressMeter(
         len(train_loader),
-        [batch_time, data_time, losses, top1, top5],
+        [batch_time, data_time, losses, acc_top1, acc_top5],
         prefix="Epoch: [{}]".format(epoch))
 
     # switch to train mode
     model.train()
 
+    # for train metrics
+    labels_true = np.array([], dtype=np.int64)
+    labels_pred = np.array([], dtype=np.int64)
+    labels_probs = []
+
     end = time.time()
     for i, (images, target) in enumerate(train_loader):
         # measure data loading time
@@ -331,14 +606,22 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
         # measure accuracy and record loss
         acc1, acc5 = accuracy(output, target, topk=(1, 5))
         losses.update(loss.item(), images.size(0))
-        top1.update(acc1[0], images.size(0))
-        top5.update(acc5[0], images.size(0))
+        acc_top1.update(acc1[0], images.size(0))
+        acc_top5.update(acc5[0], images.size(0))
 
         # compute gradient and do SGD step
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
+        with torch.no_grad():
+            predicted_values, predicted_indices = torch.max(output.data, 1)
+            labels_true = np.append(labels_true, target.cpu().numpy())
+            labels_pred = np.append(labels_pred, predicted_indices.cpu().numpy())
+
+            class_probs_batch = [F.softmax(el, dim=0) for el in output]
+            labels_probs.append(class_probs_batch)
+
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
@@ -346,10 +629,22 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
         if i % args.print_freq == 0:
             progress.display(i + 1)
 
+    if args.distributed:
+        acc_top1.all_reduce()
+        acc_top5.all_reduce()
+
+    labels_probs = torch.cat([torch.stack(batch) for batch in labels_probs]).cpu()
+    metrics = calculate_train_metrics(labels_true, labels_pred, labels_probs)
+
+    return acc_top1.avg, losses.avg, metrics
+
 
-def validate(val_loader, model, criterion, args):
-    def run_validate(loader, base_progress=0):
+def validate(val_loader, model, criterion, args) -> Tuple[float, float, ValidationMetrics]:
+    def run_validate(loader, base_progress=0) -> ValidationMetrics:
+        labels_true = np.array([], dtype=np.int64)
+        labels_pred = np.array([], dtype=np.int64)
+        labels_probs = []
+
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
@@ -369,8 +664,17 @@ def run_validate(loader, base_progress=0):
                 # measure accuracy and record loss
                 acc1, acc5 = accuracy(output, target, topk=(1, 5))
                 losses.update(loss.item(), images.size(0))
-                top1.update(acc1[0], images.size(0))
-                top5.update(acc5[0], images.size(0))
+                acc_top1.update(acc1[0], images.size(0))
+                acc_top5.update(acc5[0], images.size(0))
+
+                # measure f1, precision, recall
+                predicted_values, predicted_indices = torch.max(output.data, 1)
+                labels_true = np.append(labels_true, target.cpu().numpy())
+                labels_pred = np.append(labels_pred, predicted_indices.cpu().numpy())
+
+                class_probs_batch = [F.softmax(el, dim=0) for el in output]
+                labels_probs.append(class_probs_batch)
 
                 # measure elapsed time
                 batch_time.update(time.time() - end)
@@ -379,22 +683,28 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)
 
+        labels_probs = torch.cat([torch.stack(batch) for batch in labels_probs]).cpu()
+
+        return calculate_validation_metrics(labels_true, labels_pred, labels_probs)
+
     batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
     losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    acc_top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
+    acc_top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
-        [batch_time, losses, top1, top5],
+        [batch_time, losses, acc_top1, acc_top5],
         prefix='Test: ')
 
     # switch to evaluate mode
     model.eval()
 
-    run_validate(val_loader)
+    metrics = run_validate(val_loader)
     if args.distributed:
-        top1.all_reduce()
-        top5.all_reduce()
+        acc_top1.all_reduce()
+        acc_top5.all_reduce()
+        losses.all_reduce()
 
     if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
         aux_val_dataset = Subset(val_loader.dataset,
@@ -402,17 +712,21 @@ def run_validate(loader, base_progress=0):
         aux_val_loader = torch.utils.data.DataLoader(
             aux_val_dataset, batch_size=args.batch_size, shuffle=False,
             num_workers=args.workers, pin_memory=True)
-        run_validate(aux_val_loader, len(val_loader))
+        metrics = run_validate(aux_val_loader, len(val_loader))
 
     progress.display_summary()
 
-    return top1.avg
+    return acc_top1.avg, losses.avg, metrics
 
 
-def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
-    torch.save(state, filename)
+def save_checkpoint(state, is_best, run_info: str = "", dir_path="./"):
+    filename = f'{run_info}_checkpoint.pth.tar'
+    filepath = os.path.join(dir_path, filename)
+    log(f'Saving checkpoint to {filepath}')
+    torch.save(state, filepath)
     if is_best:
-        shutil.copyfile(filename, 'model_best.pth.tar')
+        shutil.copyfile(filepath, filepath.replace("checkpoint", "model_best"))
+
 
 class Summary(Enum):
     NONE = 0
@@ -420,8 +734,15 @@ class Summary(Enum):
     SUM = 2
     COUNT = 3
 
+
 class AverageMeter(object):
     """Computes and stores the average and current value"""
+
+    val: float
+    sum: float
+    count: int
+    avg: float
+
     def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
         self.fmt = fmt
@@ -455,7 +776,7 @@ def all_reduce(self):
     def __str__(self):
         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
         return fmtstr.format(**self.__dict__)
-    
+
     def summary(self):
         fmtstr = ''
         if self.summary_type is Summary.NONE:
@@ -468,12 +789,12 @@ def summary(self):
             fmtstr = '{name} {count:.3f}'
         else:
             raise ValueError('invalid summary type %r' % self.summary_type)
-    
+
         return fmtstr.format(**self.__dict__)
 
 
 class ProgressMeter(object):
-    def __init__(self, num_batches, meters, prefix=""):
+    def __init__(self, num_batches, meters: List[AverageMeter], prefix=""):
         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
         self.meters = meters
         self.prefix = prefix
@@ -481,18 +802,24 @@ def __init__(self, num_batches, meters, prefix=""):
     def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
         entries += [str(meter) for meter in self.meters]
-        print('\t'.join(entries))
-        
+        log('\t'.join(entries))
+
     def display_summary(self):
         entries = [" *"]
         entries += [meter.summary() for meter in self.meters]
-        print(' '.join(entries))
+        log(' '.join(entries))
 
     def _get_batch_fmtstr(self, num_batches):
         num_digits = len(str(num_batches // 1))
         fmt = '{:' + str(num_digits) + 'd}'
         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
 
+
+# ================
+# Metrics
+# ================
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
@@ -502,7 +829,6 @@ def accuracy(output, target, topk=(1,)):
         _, pred = output.topk(maxk, 1, True, True)
         pred = pred.t()
         correct = pred.eq(target.view(1, -1).expand_as(pred))
-
         res = []
         for k in topk:
             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
@@ -510,5 +836,91 @@ def accuracy(output, target, topk=(1,)):
             res.append(correct_k.mul_(100.0 / batch_size))
         return res
 
 
+def calculate_train_metrics(labels_true: np.ndarray, labels_pred: np.ndarray,
+                            labels_probs: torch.Tensor) -> TrainMetrics:
+    # sorted, so per-class results line up with sklearn's confusion_matrix label ordering
+    unique_labels = sorted(set(labels_true))
+    f1_micro = f1_score(labels_true, labels_pred, average="micro")
+    f1_macro = f1_score(labels_true, labels_pred, average="macro")
+
+    acc_balanced = balanced_accuracy_score(labels_true, labels_pred)
+    prec_micro = precision_score(labels_true, labels_pred, average="micro")
+    prec_macro = precision_score(labels_true, labels_pred, average="macro")
+    rec_micro = recall_score(labels_true, labels_pred, average="micro")
+    rec_macro = recall_score(labels_true, labels_pred, average="macro")
+
+    return TrainMetrics(
+        unique_labels,
+        acc_balanced,
+        f1_micro, f1_macro,
+        prec_micro, prec_macro,
+        rec_micro, rec_macro
+    )
+
+
+def calculate_validation_metrics(labels_true: np.ndarray, labels_pred: np.ndarray,
+                                 labels_probs: torch.Tensor) -> ValidationMetrics:
+    unique_labels = sorted(set(labels_true))
+    f1_per_class = f1_score(labels_true, labels_pred, average=None, labels=unique_labels)
+    f1_micro = f1_score(labels_true, labels_pred, average="micro")
+    f1_macro = f1_score(labels_true, labels_pred, average="macro")
+
+    acc_balanced = balanced_accuracy_score(labels_true, labels_pred)
+    prec_micro = precision_score(labels_true, labels_pred, average="micro")
+    prec_macro = precision_score(labels_true, labels_pred, average="macro")
+    rec_micro = recall_score(labels_true, labels_pred, average="micro")
+    rec_macro = recall_score(labels_true, labels_pred, average="macro")
+
+    conf_matrix = confusion_matrix(labels_true, labels_pred)
+
+    fig_pr_curve_micro, _ = plot_pr_curve_micro(len(unique_labels), labels_probs, torch.tensor(labels_true))
+
+    return ValidationMetrics(
+        unique_labels,
+        acc_balanced,
+        f1_micro, f1_macro,
+        prec_micro, prec_macro,
+        rec_micro, rec_macro,
+        [(cl, f1) for cl, f1 in zip(unique_labels, f1_per_class)],
+        conf_matrix,
+        labels_true,
+        labels_pred,
+        labels_probs,
+        fig_pr_curve_micro
+    )
+
+
+def plot_confusion_matrix(cm, class_names, normalize=False):
+    plt.switch_backend('agg')
+    fig = plt.figure(figsize=(10, 10))
+
+    if normalize:
+        cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)
+        colormap = "Greens"
+    else:
+        colormap = "Blues"
+
+    plt.imshow(cm, interpolation='nearest', cmap=colormap)
+    plt.title('Confusion matrix')
+    plt.colorbar()
+    tick_marks = np.arange(len(class_names))
+    plt.xticks(tick_marks, class_names, rotation=90)
+    plt.yticks(tick_marks, class_names)
+
+    plt.tight_layout()
+    plt.ylabel('True label')
+    plt.xlabel('Predicted label')
+    return fig, plt
+
+
+def plot_pr_curve_micro(num_classes: int, labels_probs: torch.Tensor, labels_true: torch.Tensor):
+    fig = plt.figure(figsize=(8, 8))
+    metric = MulticlassPrecisionRecallCurve(num_classes=num_classes, average="micro")
+    metric.update(labels_probs, labels_true)
+    metric.plot(ax=plt.gca())
+    plt.title("PR curve micro avg.")
+    plt.tight_layout()
+    return fig, plt
+
+
 if __name__ == '__main__':
     main()
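For reference, here is a module-definitions file that satisfies the -m/--use-module-definitions contract and the return-type checks in get_module_method. This is a minimal, hypothetical sketch: the file name custom_defs.py, the tiny CNN, the FakeData stand-ins, and the class names are all illustrative assumptions; only the function names and return annotations are the ones main.py actually looks up. Since accuracy() is computed with topk=(1, 5), the toy task uses five classes so top-5 accuracy stays well defined.

    # custom_defs.py -- hypothetical example, not part of the patch
    import logging
    from typing import Dict, Union

    import torch
    import torch.nn as nn
    import torch.utils.data
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms


    def get_model() -> nn.Module:
        # any nn.Module will do; a tiny CNN for a 5-class toy task
        return nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(16, 5))


    def get_train_dataset() -> torch.utils.data.Dataset:
        # FakeData stands in for a real ImageFolder/custom dataset
        return datasets.FakeData(512, (3, 64, 64), 5, transforms.ToTensor())


    def get_val_dataset() -> torch.utils.data.Dataset:
        return datasets.FakeData(128, (3, 64, 64), 5, transforms.ToTensor())


    def get_train_loader() -> torch.utils.data.DataLoader:
        # note: these simple loaders bypass --batch-size/--workers and distributed samplers
        return torch.utils.data.DataLoader(get_train_dataset(), batch_size=32, shuffle=True)


    def get_val_loader() -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(get_val_dataset(), batch_size=32, shuffle=False)


    def get_model_info() -> str:
        return 'TinyCNN'


    def get_dataset_info() -> str:
        return 'FakeData'


    def get_logger() -> logging.Logger:
        logging.basicConfig(level=logging.INFO)
        return logging.getLogger('custom_defs')


    def target_class_translations() -> Dict[int, str]:
        # human-readable names used in TensorBoard tags instead of 'Class-<n>'
        return {0: 'cat', 1: 'dog', 2: 'bird', 3: 'fish', 4: 'frog'}


    def get_hparams() -> Dict[str, Union[int, float, bool, str]]:
        # logged via SummaryWriter.add_hparams at the end of training
        return {'batch_size': 32, 'lr': 0.1, 'note': 'toy run'}

Once -m is given, get_model and the dataset/loader getters are required (get_module_method raises if they are missing); the logger, info, translation, and hparams hooks are optional and main.py falls back to its defaults without them.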
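Assuming custom_defs.py sits next to main.py on the import path, and that the runs/ and checkpoints/ directories already exist, a run exercising all three new flags might be invoked like this (names and values are illustrative):

    python main.py -m custom_defs.py -tb runs -ch checkpoints --epochs 60

Checkpoints are then written to checkpoints/<run-name>_checkpoint.pth.tar (with a model_best copy for the best top-1 accuracy so far), and the TensorBoard logs for the run land under runs/<run-name>.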
