1 +"""entry point for training a classifier"""
2 +import argparse
3 +import importlib
4 +import json
5 +import logging
6 +import os
7 +import pprint
8 +import sys
9 +
10 +import dill
11 +import torch
12 +import wandb
13 +from box import Box
14 +from torch.utils.data import DataLoader
15 +
16 +from lib.base_trainer import Trainer
17 +from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils
18 +from src.common.dataset import get_dataset
19 +
20 +
def parser_setup():
    # define the argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument("-D", "--debug", action='store_true')
    parser.add_argument("--config", "-c", required=False)
    parser.add_argument("--seed", required=False, type=int)

    str2bool = os_utils.str2bool
    listorstr = os_utils.listorstr

    # wandb and experiment bookkeeping
    parser.add_argument("--wandb.use", required=False, type=str2bool, default=False)
    parser.add_argument("--wandb.run_id", required=False, type=str)
    parser.add_argument("--wandb.watch", required=False, type=str2bool, default=False)
    parser.add_argument("--project", required=False, type=str, default="brain-age")
    parser.add_argument("--exp_name", required=True)
36 + parser.add_argument("--device", required=False,
37 + default="cuda" if torch.cuda.is_available() else "cpu")
38 + parser.add_argument("--result_folder", "-r", required=False)
39 + parser.add_argument("--mode", required=False, nargs="+", choices=["test", "train"],
40 + default=["test", "train"])
41 + parser.add_argument("--statefile", "-s", required=False, default=None)
42 +
43 + parser.add_argument("--data.name", "-d", required=False, choices=["brain_age"])
44 +
45 + # brain_age related arguments
46 + parser.add_argument("--data.root_path", default=None, type=str)
47 + parser.add_argument("--data.train_csv", default=None, type=str)
48 + parser.add_argument("--data.valid_csv", default=None, type=str)
49 + parser.add_argument("--data.test_csv", default=None, type=str)
50 + parser.add_argument("--data.feat_csv", default=None, type=str)
51 + parser.add_argument("--data.train_num_sample", default=-1, type=int,
52 + help="control number of training samples")
53 + parser.add_argument("--data.frame_dim", default=1, type=int, choices=[1, 2, 3],
54 + help="choose which dimension we want to slice, 1 for sagittal, "
55 + "2 for coronal, 3 for axial")
56 + parser.add_argument("--data.frame_keep_style", default="random", type=str,
57 + choices=["random", "ordered"],
58 + help="style of keeping frames when frame_keep_fraction < 1")
59 + parser.add_argument("--data.frame_keep_fraction", default=0, type=float,
60 + help="fraction of frame to keep (usually used during testing with missing "
61 + "frames)")
62 + parser.add_argument("--data.impute", default="drop", type=str,
63 + choices=["drop", "fill", "zeros", "noise"])
64 +
65 + parser.add_argument("--model.name", required=False, choices=["regression"])
66 +
67 + parser.add_argument("--model.arch.file", required=False, type=str, default=None)
68 + parser.add_argument("--model.arch.lstm_feat_dim", required=False, type=int, default=2)
69 + parser.add_argument("--model.arch.lstm_latent_dim", required=False, type=int, default=128)
70 + parser.add_argument("--model.arch.attn_num_heads", required=False, type=int, default=2)
71 + parser.add_argument("--model.arch.attn_dim", required=False, type=int, default=128)
72 + parser.add_argument("--model.arch.attn_drop", required=False, type=str2bool, default=False)
73 + parser.add_argument("--model.arch.agg_fn", required=False, type=str,
74 + choices=["mean", "max", "attention"])
75 +
76 + parser.add_argument("--train.batch_size", required=False, type=int, default=128)
77 +
78 + parser.add_argument("--train.patience", required=False, type=int, default=20)
79 + parser.add_argument("--train.max_epoch", required=False, type=int, default=100)
80 + parser.add_argument("--train.optimizer", required=False, type=str, default="adam",
81 + choices=["adam", "sgd"])
82 + parser.add_argument("--train.lr", required=False, type=float, default=1e-3)
83 + parser.add_argument("--train.weight_decay", required=False, type=float, default=5e-4)
84 + parser.add_argument("--train.gradient_norm_clip", required=False, type=float, default=-1)
85 +
86 + parser.add_argument("--train.save_strategy", required=False, nargs="+",
87 + choices=["best", "last", "init", "epoch", "current"],
88 + default=["best"])
89 + parser.add_argument("--train.log_every", required=False, type=int, default=1000)
90 +
91 + parser.add_argument("--train.stopping_criteria", required=False, type=str, default="accuracy")
92 + parser.add_argument("--train.stopping_criteria_direction", required=False,
93 + choices=["bigger", "lower"], default="bigger")
94 + parser.add_argument("--train.evaluations", required=False, nargs="*", choices=[])
95 +
96 + parser.add_argument("--train.scheduler", required=False, type=str, default=None)
97 + parser.add_argument("--train.scheduler_gamma", required=False, type=float)
98 + parser.add_argument("--train.scheduler_milestones", required=False, nargs="+")
99 + parser.add_argument("--train.scheduler_patience", required=False, type=int)
100 + parser.add_argument("--train.scheduler_step_size", required=False, type=int)
101 + parser.add_argument("--train.scheduler_load_on_reduce", required=False, type=str2bool)

    # test-time arguments
    parser.add_argument("--test.batch_size", required=False, type=int, default=128)
    parser.add_argument("--test.evaluations", required=False, nargs="*", choices=[])
    parser.add_argument("--test.eval_model", required=False, type=str,
                        choices=["best", "last", "current"], default="best")

    return parser


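# Example invocation (a sketch: the script filename "train.py" and the paths below are
# assumptions for illustration; only the flags themselves come from parser_setup above):
#   python train.py -c config.json --exp_name brain_age_baseline \
#       --data.name brain_age --model.name regression \
#       --result_folder results/brain_age_baseline --mode train test
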
if __name__ == "__main__":
    # enable cudnn autotuning; the random seed is set below once the config is parsed
    torch.backends.cudnn.benchmark = True

    # define logger etc
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
    logger = logging.getLogger()

    # parse_args turns the dotted CLI flags into a nested config
    # (e.g. --wandb.use becomes config.wandb.use)
    parser = parser_setup()
    config = os_utils.parse_args(parser)
    if config.seed is not None:
        os_utils.set_seed(config.seed)
    if config.debug:
        logger.setLevel(logging.DEBUG)

    logger.info("Config:")
    logger.info(pprint.pformat(config.to_dict(), indent=4))

    # see https://github.com/wandb/client/issues/714
    os_utils.safe_makedirs(config.result_folder)
    statefile, run_id, result_folder = os_utils.get_state_params(
        config.wandb.use, config.wandb.run_id, config.result_folder, config.statefile
    )
    config.statefile = statefile
    config.wandb.run_id = run_id
    config.result_folder = result_folder

    # when resuming from a statefile, exit early if training has already finished
    if statefile is not None:
        data = torch.load(open(statefile, "rb"), pickle_module=dill)
        epoch = data["epoch"]
        if epoch >= config.train.max_epoch:
            logger.error("Already trained up to max epoch; exiting")
            sys.exit()

    if config.wandb.use:
        wandb.init(
            name=config.exp_name if config.exp_name is not None else config.result_folder,
            config=config.to_dict(),
            project=config.project,
            dir=config.result_folder,
            resume=config.wandb.run_id,
            id=config.wandb.run_id,
            sync_tensorboard=True,
        )
        logger.info(f"Starting wandb with id {wandb.run.id}")

    # NOTE: wandb creates a git patch, so we can probably get rid of this in the future
    os_utils.copy_code("src", config.result_folder, replace=True)
    with open(f"{wandb.run.dir if config.wandb.use else config.result_folder}/config.json",
              "w") as f:
        json.dump(config.to_dict(), f)

    logger.info("Getting data and dataloaders")
    data, meta = get_dataset(**config.data, device=config.device)

    # num_workers = max(min(os.cpu_count(), 8), 1)
    num_workers = os.cpu_count()
    logger.info(f"Using {num_workers} workers")
    train_loader = DataLoader(data["train"], shuffle=True, batch_size=config.train.batch_size,
                              num_workers=num_workers)
    valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
                             num_workers=num_workers)
176 + logger.info("Getting model")
177 + # load arch module
178 + arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
179 + model_arch = arch_module.get_arch(
180 + input_shape=meta.get("input_shape"), output_size=meta.get("num_class"), **config.model.arch,
181 + slice_dim=config.data.frame_dim
182 + )
183 +
184 + # declaring models
185 + if config.model.name in "regression":
186 + from src.models.regression import Regression
187 +
188 + model = Regression(**model_arch)
189 + else:
190 + raise Exception("Unknown model")
191 +
192 + model.to(config.device)
193 + model.stats()

    if config.wandb.use and config.wandb.watch:
        wandb.watch(model, log="all")

    # declaring trainer
    optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
        model,
        lr=config.train.lr,
        optimizer=config.train.optimizer,
        opt_params={
            "weight_decay": config.train.get("weight_decay", 1e-4),
            "momentum": config.train.get("optimizer_momentum", 0.9),
        },
        scheduler=config.train.get("scheduler", None),
        scheduler_params={
            "gamma": config.train.get("scheduler_gamma", 0.1),
            "milestones": config.train.get("scheduler_milestones", [100, 200, 300]),
            "patience": config.train.get("scheduler_patience", 100),
            "step_size": config.train.get("scheduler_step_size", 100),
            "load_on_reduce": config.train.get("scheduler_load_on_reduce"),
            # "mode" mirrors the stopping-criterion direction ("bigger" -> maximize)
            "mode": "max" if config.train.get(
                "stopping_criteria_direction") == "bigger" else "min",
        },
    )
    trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
                      result_dir=config.result_folder, log_every=config.train.log_every,
                      save_strategy=config.train.save_strategy,
                      patience=config.train.patience,
                      max_epoch=config.train.max_epoch,
                      stopping_criteria=config.train.stopping_criteria,
                      gradient_norm_clip=config.train.gradient_norm_clip,
                      stopping_criteria_direction=config.train.stopping_criteria_direction,
                      evaluations=Box({"train": config.train.evaluations,
                                       "test": config.test.evaluations}))

229 + if "train" in config.mode:
230 + logger.info("starting training")
231 + trainer.train(train_loader, valid_loader)
232 + logger.info("Training done;")
233 +
234 + # copy current step and write test results to
235 + step_to_write = trainer.step
236 + step_to_write += 1
237 +
238 + if "test" in config.mode and config.test.eval_model == "best":
239 + if os.path.exists(f"{trainer.result_dir}/best_model.pt"):
240 + logger.info("Loading best model")
241 + trainer.load(f"{trainer.result_dir}/best_model.pt")
242 + else:
243 + logger.info("eval_model is best, but best model not found ::: evaling last model")
244 + else:
245 + logger.info("eval model is not best, so skipping loading at end of training")
246 +
247 + if "test" in config.mode:
248 + logger.info("evaluating model on test set")
249 + logger.info(f"Model was trained upto {trainer.epoch}")
250 + # copy current step and write test results to
251 + step_to_write = trainer.step
252 + step_to_write += 1
253 + loss, aux_loss = trainer.test(train_loader, test_loader)
254 + logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
255 + force_print=True, step=step_to_write,
256 + epoch=trainer.epoch,
257 + log_every=trainer.log_every, string="test",
258 + new_line=True)
259 +
260 + loss, aux_loss = trainer.test(train_loader, train_loader)
261 + logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
262 + force_print=True, step=step_to_write,
263 + epoch=trainer.epoch,
264 + log_every=trainer.log_every, string="train_eval",
265 + new_line=True)
266 +
267 + loss, aux_loss = trainer.test(train_loader, valid_loader)
268 + logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
269 + force_print=True, step=step_to_write,
270 + epoch=trainer.epoch,
271 + log_every=trainer.log_every, string="valid_eval",
272 + new_line=True)