
Add some new flags: use_module_definitions to load custom model, dataset and loaders; checkpoints; add f1, prec, recall calc. #1237

Open · wants to merge 18 commits into main

Changes from 1 commit
Add -ch / --checkpoints flag for dir + add f1,prec,rec scores + tb log of model and hyperparams
attilamester committed Mar 4, 2024
commit 02b9d9b3ac7bb7ba09b96c1dc18f4daa049b3c9f
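
As a quick reference, here is a minimal, self-contained sketch of the new checkpoint flag in isolation (the values are hypothetical; the real parser lives in imagenet/main.py):

import argparse

# Mirrors only the flag added in this commit, not the full parser.
parser = argparse.ArgumentParser()
parser.add_argument('-ch', '--checkpoints', default='', type=str,
                    help='path to checkpoints dir (default: none)')
args = parser.parse_args(['-ch', './checkpoints'])
print(args.checkpoints)  # -> ./checkpoints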
imagenet/main.py: 163 changes (125 additions & 38 deletions)
@@ -6,8 +6,9 @@
import time
import warnings
from enum import Enum
from typing import get_type_hints
from typing import get_type_hints, Tuple, List, Union, Dict

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
@@ -20,6 +21,7 @@
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset
from torch.utils.tensorboard import SummaryWriter
@@ -68,6 +70,8 @@
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-ch', '--checkpoints', default='', type=str,
help='path to checkpoints dir (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
@@ -164,6 +168,19 @@ def get_module_method(module_name, method_name, expected_type_hint):
raise Exception(f'The provided module {module_name} does not have method {method_name}')


def get_run_name(model, train_dataset, val_dataset):
today = datetime.datetime.now().strftime("%Y-%m-%d")
model_info = ""
train_dataset_info = len(train_dataset)
val_dataset_info = len(val_dataset)

if callable(getattr(model, "get_info", None)):
model_info = f"-{model.get_info()}"
return (f"{today}_{model.__class__.__name__}{model_info}"
f"_{train_dataset.__class__.__name__}-{train_dataset_info}"
f"_{val_dataset.__class__.__name__}-{val_dataset_info}")


def main_worker(gpu, ngpus_per_node, args):
global best_acc1
args.gpu = gpu
@@ -323,27 +340,35 @@ def main_worker(gpu, ngpus_per_node, args):
train_loader = get_module_method(module, 'get_train_loader', torch.utils.data.DataLoader)
val_loader = get_module_method(module, 'get_val_loader', torch.utils.data.DataLoader)

target_class_translations = None
if args.use_module_definitions:
try:
module = safe_import(args.use_module_definitions.replace('.py', ''))
target_class_translations = get_module_method(module, 'target_class_translations', Dict[int, str])
print(f'Loaded target_class_translations from {args.use_module_definitions}')
except Exception as e:
print(f'Error getting target_class_translations from {args.use_module_definitions}: {e}')

def get_target_class(cl: int) -> str:
if target_class_translations:
return target_class_translations[cl]
return f"Class-{cl}"

if args.evaluate:
validate(val_loader, model, criterion, args)
return

run_name = get_run_name(model, train_dataset, val_dataset)
tensorboard_writer = None
if args.tb_summary_writer_dir:
today = datetime.datetime.now().strftime("%Y-%m-%d")
model_info = ""
train_dataset_info = len(train_dataset)
val_dataset_info = len(val_dataset)

if callable(getattr(model, "get_info", None)):
model_info = f"-{model.get_info()}"

tb_log_dir_name = (f"{today}_{model.__class__.__name__}{model_info}"
f"_{train_dataset.__class__.__name__}-{train_dataset_info}"
f"_{val_dataset.__class__.__name__}-{val_dataset_info}")
tb_log_dir_path = os.path.join(args.tb_summary_writer_dir, tb_log_dir_name)
tb_log_dir_path = os.path.join(args.tb_summary_writer_dir, run_name)
tensorboard_writer = SummaryWriter(tb_log_dir_path)
print(f'TensorBoard summary writer is created at {tb_log_dir_path}')

if tensorboard_writer:
image, label = next(iter(train_loader))
tensorboard_writer.add_graph(model, image)
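# Note: add_graph traces a forward pass with this sample batch, so `image` must be on
# the same device as the model; with CUDA the batch may need image.to(device) first.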

for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
@@ -352,39 +377,52 @@ def main_worker(gpu, ngpus_per_node, args):
train_loss = train(train_loader, model, criterion, optimizer, epoch, device, args)

# evaluate on validation set
acc1 = validate(val_loader, model, criterion, args)
(acc1,
f1_micro, f1_macro,
prec_micro, prec_macro,
rec_micro, rec_macro,
f1_per_class) = validate(val_loader, model, criterion, args)

scheduler.step()

# remember best acc@1 and save checkpoint
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)

if not args.multiprocessing_distributed or (args.multiprocessing_distributed
and args.rank % ngpus_per_node == 0):
if not args.multiprocessing_distributed or \
(args.multiprocessing_distributed and args.rank % ngpus_per_node == 0) or \
epoch == args.epochs - 1:
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_acc1': best_acc1,
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict()
}, is_best)
}, is_best, run_name, args.checkpoints)

if tensorboard_writer:
tensorboard_writer.add_scalars('Loss', dict(train=train_loss), epoch + 1)
tensorboard_writer.add_scalars('Accuracy', dict(val=acc1), epoch + 1)
tensorboard_writer.add_scalars('F1', dict(micro=f1_micro, macro=f1_macro), epoch + 1)
tensorboard_writer.add_scalars('Precision', dict(micro=prec_micro, macro=prec_macro), epoch + 1)
tensorboard_writer.add_scalars('Recall', dict(micro=rec_micro, macro=rec_macro), epoch + 1)
tensorboard_writer.add_scalars('F1/class', {get_target_class(cl): f1 for cl, f1 in f1_per_class}, epoch + 1)

tensorboard_writer.add_hparams({"param1": 1, "param2": 2},
{"Accuracy": best_acc1})


def train(train_loader, model, criterion, optimizer, epoch, device, args) -> float:
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
acc_top1 = AverageMeter('Acc@1', ':6.2f')
acc_top5 = AverageMeter('Acc@5', ':6.2f')

progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses, top1, top5],
[batch_time, data_time, losses, acc_top1, acc_top5],
prefix="Epoch: [{}]".format(epoch))

# switch to train mode
@@ -406,8 +444,8 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> float:
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
acc_top1.update(acc1[0], images.size(0))
acc_top5.update(acc5[0], images.size(0))

# compute gradient and do SGD step
optimizer.zero_grad()
@@ -424,8 +462,22 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args) -> float:
return loss.item()


def validate(val_loader, model, criterion, args):
def run_validate(loader, base_progress=0):
def validate(val_loader, model, criterion, args) -> Tuple[
float, float, float, float, float, float, float, List[Tuple[int, float]]]:
"""
:return: acc1,
f1_micro, f1_macro,
prec_micro, prec_macro,
rec_micro, rec_macro,
f1_per_class: [(target-index, f1), ]
"""

def run_validate(loader, base_progress=0) -> Tuple[
float, float, float, float, float, float, List[Tuple[int, float]]
]:
labels_true = np.array([])
labels_pred = np.array([])

with torch.no_grad():
end = time.time()
for i, (images, target) in enumerate(loader):
@@ -445,8 +497,14 @@ def run_validate(loader, base_progress=0):
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
acc_top1.update(acc1[0], images.size(0))
acc_top5.update(acc5[0], images.size(0))

# measure f1, precision, recall
with torch.no_grad():
predicted_values, predicted_indices = torch.max(output.data, 1)
labels_true = np.append(labels_true, target.cpu().numpy())  # target may live on GPU; move to host first
labels_pred = np.append(labels_pred, predicted_indices.cpu().numpy())

# measure elapsed time
batch_time.update(time.time() - end)
@@ -455,40 +513,47 @@ def run_validate(loader, base_progress=0):
if i % args.print_freq == 0:
progress.display(i + 1)

return metrics_labels_true_pred(labels_true, labels_pred)

batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
losses = AverageMeter('Loss', ':.4e', Summary.NONE)
top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
acc_top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
acc_top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)

progress = ProgressMeter(
len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
[batch_time, losses, top1, top5],
[batch_time, losses, acc_top1, acc_top5],
prefix='Test: ')

# switch to evaluate mode
model.eval()

run_validate(val_loader)
metrics = run_validate(val_loader)
if args.distributed:
top1.all_reduce()
top5.all_reduce()
acc_top1.all_reduce()
acc_top5.all_reduce()

if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
aux_val_dataset = Subset(val_loader.dataset,
range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
aux_val_loader = torch.utils.data.DataLoader(
aux_val_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
run_validate(aux_val_loader, len(val_loader))
metrics = run_validate(aux_val_loader, len(val_loader))
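# Note: in this tail case `metrics` reflects only the leftover aux samples, and the
# sklearn-based scores are per process (unlike the meters, they are not all-reduced).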

progress.display_summary()
f1_micro, f1_macro, prec_micro, prec_macro, rec_micro, rec_macro, f1_per_class = metrics

return top1.avg
return acc_top1.avg, f1_micro, f1_macro, prec_micro, prec_macro, rec_micro, rec_macro, f1_per_class


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
def save_checkpoint(state, is_best, run_info: str = "", dir="./"):
filename = f'{run_info}_checkpoint.pth.tar'
filepath = os.path.join(dir, filename)
print(f'Saving checkpoint to {filepath}')
torch.save(state, filepath)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
shutil.copyfile(filepath, os.path.join(dir, filename.replace("checkpoint", "model_best")))  # replace only in the file name, not in the dir path
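# e.g. run_info="2024-03-04_ResNet_..." and dir="./ckpts" (both hypothetical) write
#   ./ckpts/2024-03-04_ResNet_..._checkpoint.pth.tar
# plus, on the best epoch so far, a sibling copy named ..._model_best.pth.tar.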


class Summary(Enum):
@@ -501,6 +566,11 @@ class Summary(Enum):
class AverageMeter(object):
"""Computes and stores the average and current value"""

val: float
sum: float
count: int
avg: float

def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
self.name = name
self.fmt = fmt
@@ -552,7 +622,7 @@ def summary(self):


class ProgressMeter(object):
def __init__(self, num_batches, meters, prefix=""):
def __init__(self, num_batches, meters: List[AverageMeter], prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
@@ -590,5 +660,22 @@ def accuracy(output, target, topk=(1,)):
return res


def metrics_labels_true_pred(labels_true: np.ndarray, labels_pred: np.ndarray) -> Tuple[
float, float, float, float, float, float, List[Tuple[Union[int, str], float]]
]:
unique_labels = sorted(set(labels_true))  # deterministic class order for the per-class scores
f1_per_class = f1_score(labels_true, labels_pred, average=None, labels=unique_labels)
f1_micro = f1_score(labels_true, labels_pred, average="micro")
f1_macro = f1_score(labels_true, labels_pred, average="macro")

prec_micro = precision_score(labels_true, labels_pred, average="micro")
prec_macro = precision_score(labels_true, labels_pred, average="macro")
rec_micro = recall_score(labels_true, labels_pred, average="micro")
rec_macro = recall_score(labels_true, labels_pred, average="macro")

# the accumulated label arrays are float64, so cast class ids back to int for the mapping
return f1_micro, f1_macro, prec_micro, prec_macro, rec_micro, rec_macro, [(int(cl), float(f1)) for cl, f1 in zip(unique_labels, f1_per_class)]
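# Toy example (hypothetical): labels_true=[0, 0, 0, 1], labels_pred=[0, 0, 0, 0]
#   per-class F1: class 0 ≈ 0.857, class 1 = 0.0
#   micro F1 = 0.75 (equals plain accuracy for single-label multiclass)
#   macro F1 ≈ 0.429 (unweighted mean of per-class F1, so the missed class drags it down)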


if __name__ == '__main__':
main()