Hyunji

Upload new file

"""entry point for training a classifier"""
import argparse
import importlib
import json
import logging
import os
import pprint
import sys
import dill
import torch
import wandb
from box import Box
from torch.utils.data import DataLoader
from lib.base_trainer import Trainer
from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils
from src.common.dataset import get_dataset
def parser_setup():
# define argparsers
parser = argparse.ArgumentParser()
parser.add_argument("-D", "--debug", action='store_true')
parser.add_argument("--config", "-c", required=False)
parser.add_argument("--seed", required=False, type=int)
str2bool = os_utils.str2bool
listorstr = os_utils.listorstr
parser.add_argument("--wandb.use", required=False, type=str2bool, default=False)
parser.add_argument("--wandb.run_id", required=False, type=str)
parser.add_argument("--wandb.watch", required=False, type=str2bool, default=False)
parser.add_argument("--project", required=False, type=str, default="brain-age")
parser.add_argument("--exp_name", required=True)
parser.add_argument("--device", required=False,
default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--result_folder", "-r", required=False)
parser.add_argument("--mode", required=False, nargs="+", choices=["test", "train"],
default=["test", "train"])
parser.add_argument("--statefile", "-s", required=False, default=None)
parser.add_argument("--data.name", "-d", required=False, choices=["brain_age"])
# brain_age related arguments
parser.add_argument("--data.root_path", default=None, type=str)
parser.add_argument("--data.train_csv", default=None, type=str)
parser.add_argument("--data.valid_csv", default=None, type=str)
parser.add_argument("--data.test_csv", default=None, type=str)
parser.add_argument("--data.feat_csv", default=None, type=str)
parser.add_argument("--data.train_num_sample", default=-1, type=int,
help="control number of training samples")
parser.add_argument("--data.frame_dim", default=1, type=int, choices=[1, 2, 3],
help="choose which dimension we want to slice, 1 for sagittal, "
"2 for coronal, 3 for axial")
parser.add_argument("--data.frame_keep_style", default="random", type=str,
choices=["random", "ordered"],
help="style of keeping frames when frame_keep_fraction < 1")
parser.add_argument("--data.frame_keep_fraction", default=0, type=float,
help="fraction of frame to keep (usually used during testing with missing "
"frames)")
parser.add_argument("--data.impute", default="drop", type=str,
choices=["drop", "fill", "zeros", "noise"])
parser.add_argument("--model.name", required=False, choices=["regression"])
parser.add_argument("--model.arch.file", required=False, type=str, default=None)
parser.add_argument("--model.arch.lstm_feat_dim", required=False, type=int, default=2)
parser.add_argument("--model.arch.lstm_latent_dim", required=False, type=int, default=128)
parser.add_argument("--model.arch.attn_num_heads", required=False, type=int, default=2)
parser.add_argument("--model.arch.attn_dim", required=False, type=int, default=128)
parser.add_argument("--model.arch.attn_drop", required=False, type=str2bool, default=False)
parser.add_argument("--model.arch.agg_fn", required=False, type=str,
choices=["mean", "max", "attention"])
parser.add_argument("--train.batch_size", required=False, type=int, default=128)
parser.add_argument("--train.patience", required=False, type=int, default=20)
parser.add_argument("--train.max_epoch", required=False, type=int, default=100)
parser.add_argument("--train.optimizer", required=False, type=str, default="adam",
choices=["adam", "sgd"])
parser.add_argument("--train.lr", required=False, type=float, default=1e-3)
parser.add_argument("--train.weight_decay", required=False, type=float, default=5e-4)
parser.add_argument("--train.gradient_norm_clip", required=False, type=float, default=-1)
parser.add_argument("--train.save_strategy", required=False, nargs="+",
choices=["best", "last", "init", "epoch", "current"],
default=["best"])
parser.add_argument("--train.log_every", required=False, type=int, default=1000)
parser.add_argument("--train.stopping_criteria", required=False, type=str, default="accuracy")
parser.add_argument("--train.stopping_criteria_direction", required=False,
choices=["bigger", "lower"], default="bigger")
parser.add_argument("--train.evaluations", required=False, nargs="*", choices=[])
parser.add_argument("--train.scheduler", required=False, type=str, default=None)
parser.add_argument("--train.scheduler_gamma", required=False, type=float)
parser.add_argument("--train.scheduler_milestones", required=False, nargs="+")
parser.add_argument("--train.scheduler_patience", required=False, type=int)
parser.add_argument("--train.scheduler_step_size", required=False, type=int)
parser.add_argument("--train.scheduler_load_on_reduce", required=False, type=str2bool)
#
parser.add_argument("--test.batch_size", required=False, type=int, default=128)
parser.add_argument("--test.evaluations", required=False, nargs="*", choices=[])
parser.add_argument("--test.eval_model", required=False, type=str,
choices=["best", "last", "current"], default="best")
return parser
if __name__ == "__main__":
# set seeds etc here
torch.backends.cudnn.benchmark = True
# define logger etc
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger()
parser = parser_setup()
config = os_utils.parse_args(parser)
if config.seed is not None:
os_utils.set_seed(config.seed)
if config.debug:
logger.setLevel(logging.DEBUG)
logger.info("Config:")
logger.info(pprint.pformat(config.to_dict(), indent=4))
# see https://github.com/wandb/client/issues/714
os_utils.safe_makedirs(config.result_folder)
statefile, run_id, result_folder = os_utils.get_state_params(
config.wandb.use, config.wandb.run_id, config.result_folder, config.statefile
)
config.statefile = statefile
config.wandb.run_id = run_id
config.result_folder = result_folder
if statefile is not None:
data = torch.load(open(statefile, "rb"), pickle_module=dill)
epoch = data["epoch"]
if epoch >= config.train.max_epoch:
logger.error("Aleady trained upto max epoch; exiting")
sys.exit()
if config.wandb.use:
wandb.init(
name=config.exp_name if config.exp_name is not None else config.result_folder,
config=config.to_dict(),
project=config.project,
dir=config.result_folder,
resume=config.wandb.run_id,
id=config.wandb.run_id,
sync_tensorboard=True,
)
logger.info(f"Starting wandb with id {wandb.run.id}")
# NOTE: WANDB creates git patch so we probably can get rid of this in future
os_utils.copy_code("src", config.result_folder, replace=True)
json.dump(
config.to_dict(),
open(f"{wandb.run.dir if config.wandb.use else config.result_folder}/config.json", "w")
)
logger.info("Getting data and dataloaders")
data, meta = get_dataset(**config.data, device=config.device)
# num_workers = max(min(os.cpu_count(), 8), 1)
num_workers = os.cpu_count()
logger.info(f"Using {num_workers} workers")
train_loader = DataLoader(data["train"], shuffle=True, batch_size=config.train.batch_size,
num_workers=num_workers)
valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
num_workers=num_workers)
test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
num_workers=num_workers)
logger.info("Getting model")
# load arch module
arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
model_arch = arch_module.get_arch(
input_shape=meta.get("input_shape"), output_size=meta.get("num_class"), **config.model.arch,
slice_dim=config.data.frame_dim
)
# declaring models
if config.model.name in "regression":
from src.models.regression import Regression
model = Regression(**model_arch)
else:
raise Exception("Unknown model")
model.to(config.device)
model.stats()
if config.wandb.use and config.wandb.watch:
wandb.watch(model, log="all")
# declaring trainer
optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
model,
lr=config.train.lr,
optimizer=config.train.optimizer,
opt_params={
"weight_decay": config.train.get("weight_decay", 1e-4),
"momentum" : config.train.get("optimizer_momentum", 0.9)
},
scheduler=config.train.get("scheduler", None),
scheduler_params={
"gamma" : config.train.get("scheduler_gamma", 0.1),
"milestones" : config.train.get("scheduler_milestones", [100, 200, 300]),
"patience" : config.train.get("scheduler_patience", 100),
"step_size" : config.train.get("scheduler_step_size", 100),
"load_on_reduce": config.train.get("scheduler_load_on_reduce"),
"mode" : "max" if config.train.get(
"stopping_criteria_direction") == "bigger" else "min"
},
)
trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
result_dir=config.result_folder, log_every=config.train.log_every,
save_strategy=config.train.save_strategy,
patience=config.train.patience,
max_epoch=config.train.max_epoch,
stopping_criteria=config.train.stopping_criteria,
gradient_norm_clip=config.train.gradient_norm_clip,
stopping_criteria_direction=config.train.stopping_criteria_direction,
evaluations=Box({"train": config.train.evaluations,
"test" : config.test.evaluations}))
if "train" in config.mode:
logger.info("starting training")
trainer.train(train_loader, valid_loader)
logger.info("Training done;")
# copy current step and write test results to
step_to_write = trainer.step
step_to_write += 1
if "test" in config.mode and config.test.eval_model == "best":
if os.path.exists(f"{trainer.result_dir}/best_model.pt"):
logger.info("Loading best model")
trainer.load(f"{trainer.result_dir}/best_model.pt")
else:
logger.info("eval_model is best, but best model not found ::: evaling last model")
else:
logger.info("eval model is not best, so skipping loading at end of training")
if "test" in config.mode:
logger.info("evaluating model on test set")
logger.info(f"Model was trained upto {trainer.epoch}")
# copy current step and write test results to
step_to_write = trainer.step
step_to_write += 1
loss, aux_loss = trainer.test(train_loader, test_loader)
logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
force_print=True, step=step_to_write,
epoch=trainer.epoch,
log_every=trainer.log_every, string="test",
new_line=True)
loss, aux_loss = trainer.test(train_loader, train_loader)
logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
force_print=True, step=step_to_write,
epoch=trainer.epoch,
log_every=trainer.log_every, string="train_eval",
new_line=True)
loss, aux_loss = trainer.test(train_loader, valid_loader)
logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
force_print=True, step=step_to_write,
epoch=trainer.epoch,
log_every=trainer.log_every, string="valid_eval",
new_line=True)