Skip to content

Add some new flags: use_module_definitions to load custom model, dataset and loaders ; checkpoints ; add f1,prec,recall calc. #1237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Get module_info from -m
  • Loading branch information
attilamester committed Mar 6, 2024
commit 707e492ed26bceb7b1fdb38a86945f5d227e4807
14 changes: 11 additions & 3 deletions imagenet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,17 @@ def get_module_method(module_name, method_name, expected_type_hint):
raise Exception(f'The provided module {module_name} does not have method {method_name}')


def get_run_name(model, train_dataset, val_dataset):
def get_run_name(model, train_dataset, val_dataset, args):
today = datetime.datetime.now().strftime('%m%d-%H%M')
model_info = model.get_info() if callable(getattr(model, "get_info", None)) else model.__class__.__name__

model_info = model.__class__.__name__
if args.use_module_definitions:
module = safe_import(args.use_module_definitions.replace('.py', ''))
try:
model_info = get_module_method(module, 'get_model_info', str)
except:
pass

train_dataset_info = train_dataset.get_info() if callable(
getattr(train_dataset, "get_info", None)) else train_dataset.__class__.__name__
train_dataset_size = len(train_dataset)
Expand Down Expand Up @@ -366,7 +374,7 @@ def get_target_class(cl: int) -> str:
validate(val_loader, model, criterion, args)
return

run_name = get_run_name(model, train_dataset, val_dataset)
run_name = get_run_name(model, train_dataset, val_dataset, args)
tensorboard_writer = None
if args.tb_summary_writer_dir:
tb_log_dir_path = os.path.join(args.tb_summary_writer_dir, run_name)
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy