Showing 1 changed file with 272 additions and 0 deletions
src/scripts/main.py
0 → 100644
"""entry point for training a classifier"""
import argparse
import importlib
import json
import logging
import os
import pprint
import sys

import dill
import torch
import wandb
from box import Box
from torch.utils.data import DataLoader

from lib.base_trainer import Trainer
from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils
from src.common.dataset import get_dataset


def parser_setup():
    # define argparsers
    parser = argparse.ArgumentParser()
    parser.add_argument("-D", "--debug", action='store_true')
    parser.add_argument("--config", "-c", required=False)
    parser.add_argument("--seed", required=False, type=int)

    str2bool = os_utils.str2bool
    listorstr = os_utils.listorstr
    parser.add_argument("--wandb.use", required=False, type=str2bool, default=False)
    parser.add_argument("--wandb.run_id", required=False, type=str)
    parser.add_argument("--wandb.watch", required=False, type=str2bool, default=False)
    parser.add_argument("--project", required=False, type=str, default="brain-age")
    parser.add_argument("--exp_name", required=True)

    parser.add_argument("--device", required=False,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--result_folder", "-r", required=False)
    parser.add_argument("--mode", required=False, nargs="+", choices=["test", "train"],
                        default=["test", "train"])
    parser.add_argument("--statefile", "-s", required=False, default=None)

    parser.add_argument("--data.name", "-d", required=False, choices=["brain_age"])

    # brain_age related arguments
    parser.add_argument("--data.root_path", default=None, type=str)
    parser.add_argument("--data.train_csv", default=None, type=str)
    parser.add_argument("--data.valid_csv", default=None, type=str)
    parser.add_argument("--data.test_csv", default=None, type=str)
    parser.add_argument("--data.feat_csv", default=None, type=str)
    parser.add_argument("--data.train_num_sample", default=-1, type=int,
                        help="control number of training samples")
    parser.add_argument("--data.frame_dim", default=1, type=int, choices=[1, 2, 3],
                        help="choose which dimension we want to slice, 1 for sagittal, "
                             "2 for coronal, 3 for axial")
    parser.add_argument("--data.frame_keep_style", default="random", type=str,
                        choices=["random", "ordered"],
                        help="style of keeping frames when frame_keep_fraction < 1")
    parser.add_argument("--data.frame_keep_fraction", default=0, type=float,
                        help="fraction of frames to keep (usually used during testing with "
                             "missing frames)")
    parser.add_argument("--data.impute", default="drop", type=str,
                        choices=["drop", "fill", "zeros", "noise"])

    parser.add_argument("--model.name", required=False, choices=["regression"])

    parser.add_argument("--model.arch.file", required=False, type=str, default=None)
    parser.add_argument("--model.arch.lstm_feat_dim", required=False, type=int, default=2)
    parser.add_argument("--model.arch.lstm_latent_dim", required=False, type=int, default=128)
    parser.add_argument("--model.arch.attn_num_heads", required=False, type=int, default=2)
    parser.add_argument("--model.arch.attn_dim", required=False, type=int, default=128)
    parser.add_argument("--model.arch.attn_drop", required=False, type=str2bool, default=False)
    parser.add_argument("--model.arch.agg_fn", required=False, type=str,
                        choices=["mean", "max", "attention"])

    parser.add_argument("--train.batch_size", required=False, type=int, default=128)

    parser.add_argument("--train.patience", required=False, type=int, default=20)
    parser.add_argument("--train.max_epoch", required=False, type=int, default=100)
    parser.add_argument("--train.optimizer", required=False, type=str, default="adam",
                        choices=["adam", "sgd"])
    parser.add_argument("--train.lr", required=False, type=float, default=1e-3)
    parser.add_argument("--train.weight_decay", required=False, type=float, default=5e-4)
    parser.add_argument("--train.gradient_norm_clip", required=False, type=float, default=-1)

    parser.add_argument("--train.save_strategy", required=False, nargs="+",
                        choices=["best", "last", "init", "epoch", "current"],
                        default=["best"])
    parser.add_argument("--train.log_every", required=False, type=int, default=1000)

    parser.add_argument("--train.stopping_criteria", required=False, type=str, default="accuracy")
    parser.add_argument("--train.stopping_criteria_direction", required=False,
                        choices=["bigger", "lower"], default="bigger")
    parser.add_argument("--train.evaluations", required=False, nargs="*", choices=[])

    parser.add_argument("--train.scheduler", required=False, type=str, default=None)
    parser.add_argument("--train.scheduler_gamma", required=False, type=float)
    parser.add_argument("--train.scheduler_milestones", required=False, nargs="+")
    parser.add_argument("--train.scheduler_patience", required=False, type=int)
    parser.add_argument("--train.scheduler_step_size", required=False, type=int)
    parser.add_argument("--train.scheduler_load_on_reduce", required=False, type=str2bool)

    parser.add_argument("--test.batch_size", required=False, type=int, default=128)
    parser.add_argument("--test.evaluations", required=False, nargs="*", choices=[])
    parser.add_argument("--test.eval_model", required=False, type=str,
                        choices=["best", "last", "current"], default="best")

    return parser

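
# NOTE: dotted flag names such as "--wandb.use" or "--train.lr" are assumed to be collapsed by
# os_utils.parse_args into a nested Box config; that is why attributes like config.wandb.use and
# config.train.lr are available in the main block below.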
if __name__ == "__main__":
    # set seeds etc here
    torch.backends.cudnn.benchmark = True

    # define logger etc
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
    logger = logging.getLogger()

    parser = parser_setup()
    config = os_utils.parse_args(parser)
    if config.seed is not None:
        os_utils.set_seed(config.seed)
    if config.debug:
        logger.setLevel(logging.DEBUG)

    logger.info("Config:")
    logger.info(pprint.pformat(config.to_dict(), indent=4))

    # see https://github.com/wandb/client/issues/714
    os_utils.safe_makedirs(config.result_folder)
    statefile, run_id, result_folder = os_utils.get_state_params(
        config.wandb.use, config.wandb.run_id, config.result_folder, config.statefile
    )
    config.statefile = statefile
    config.wandb.run_id = run_id
    config.result_folder = result_folder

    if statefile is not None:
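        # the checkpoint is read with dill as the pickle module, so objects the standard pickler
        # cannot handle (e.g. lambdas) round-trip; only the stored "epoch" is inspected here to
        # decide whether training has already finished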
        data = torch.load(open(statefile, "rb"), pickle_module=dill)
        epoch = data["epoch"]
        if epoch >= config.train.max_epoch:
            logger.error("Already trained up to max epoch; exiting")
            sys.exit()

    if config.wandb.use:
        wandb.init(
            name=config.exp_name if config.exp_name is not None else config.result_folder,
            config=config.to_dict(),
            project=config.project,
            dir=config.result_folder,
            resume=config.wandb.run_id,
            id=config.wandb.run_id,
            sync_tensorboard=True,
        )
        logger.info(f"Starting wandb with id {wandb.run.id}")

    # NOTE: wandb creates a git patch, so we can probably get rid of this in the future
    os_utils.copy_code("src", config.result_folder, replace=True)
    json.dump(
        config.to_dict(),
        open(f"{wandb.run.dir if config.wandb.use else config.result_folder}/config.json", "w")
    )

    logger.info("Getting data and dataloaders")
    data, meta = get_dataset(**config.data, device=config.device)

    # num_workers = max(min(os.cpu_count(), 8), 1)
    num_workers = os.cpu_count()
    logger.info(f"Using {num_workers} workers")
    train_loader = DataLoader(data["train"], shuffle=True, batch_size=config.train.batch_size,
                              num_workers=num_workers)
    valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
                             num_workers=num_workers)
    logger.info("Getting model")
    # load arch module
    arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
    model_arch = arch_module.get_arch(
        input_shape=meta.get("input_shape"), output_size=meta.get("num_class"), **config.model.arch,
        slice_dim=config.data.frame_dim
    )
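    # get_arch is expected to return a dict of keyword arguments for the model constructor,
    # which is unpacked into Regression(**model_arch) below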

    # declaring models
    if config.model.name == "regression":
        from src.models.regression import Regression

        model = Regression(**model_arch)
    else:
        raise Exception("Unknown model")

    model.to(config.device)
    model.stats()

    if config.wandb.use and config.wandb.watch:
        wandb.watch(model, log="all")

    # declaring trainer
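    # NOTE (assumption): scheduler_params below is a superset of settings; the helper in
    # lib.utils.optimizer is expected to read only the keys relevant to the chosen scheduler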
    optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
        model,
        lr=config.train.lr,
        optimizer=config.train.optimizer,
        opt_params={
            "weight_decay": config.train.get("weight_decay", 1e-4),
            "momentum": config.train.get("optimizer_momentum", 0.9)
        },
        scheduler=config.train.get("scheduler", None),
        scheduler_params={
            "gamma": config.train.get("scheduler_gamma", 0.1),
            "milestones": config.train.get("scheduler_milestones", [100, 200, 300]),
            "patience": config.train.get("scheduler_patience", 100),
            "step_size": config.train.get("scheduler_step_size", 100),
            "load_on_reduce": config.train.get("scheduler_load_on_reduce"),
            "mode": "max" if config.train.get(
                "stopping_criteria_direction") == "bigger" else "min"
        },
    )
    trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
                      result_dir=config.result_folder, log_every=config.train.log_every,
                      save_strategy=config.train.save_strategy,
                      patience=config.train.patience,
                      max_epoch=config.train.max_epoch,
                      stopping_criteria=config.train.stopping_criteria,
                      gradient_norm_clip=config.train.gradient_norm_clip,
                      stopping_criteria_direction=config.train.stopping_criteria_direction,
                      evaluations=Box({"train": config.train.evaluations,
                                       "test": config.test.evaluations}))

    if "train" in config.mode:
        logger.info("Starting training")
        trainer.train(train_loader, valid_loader)
        logger.info("Training done")

        # write test results at the step just after the last training step
        step_to_write = trainer.step
        step_to_write += 1

        if "test" in config.mode and config.test.eval_model == "best":
            if os.path.exists(f"{trainer.result_dir}/best_model.pt"):
                logger.info("Loading best model")
                trainer.load(f"{trainer.result_dir}/best_model.pt")
            else:
                logger.info("eval_model is best, but best model not found; evaluating last model")
        else:
            logger.info("eval_model is not best, so skipping loading at end of training")

    if "test" in config.mode:
        logger.info("Evaluating model on the test set")
        logger.info(f"Model was trained up to epoch {trainer.epoch}")
        # write test results at the step just after the last training step
        step_to_write = trainer.step
        step_to_write += 1
        loss, aux_loss = trainer.test(train_loader, test_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="test",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, train_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="train_eval",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, valid_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="valid_eval",
                                         new_line=True)
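
The script imports the architecture module dynamically: the path passed via --model.arch.file has "/" replaced by "." and the ".py" suffix stripped before importlib.import_module loads it, and the module must expose a get_arch function whose return value is unpacked into the Regression constructor. A minimal sketch of such a module is shown below; the file name, the "encoder" keyword, and the layer choices are hypothetical, and only the get_arch call signature and the dict-of-kwargs contract follow from main.py.

# hypothetical file: src/models/arch_example.py (a sketch, not part of this diff)
import math

import torch.nn as nn


def get_arch(input_shape, output_size, lstm_feat_dim=2, lstm_latent_dim=128,
             attn_num_heads=2, attn_dim=128, attn_drop=False, agg_fn="mean",
             slice_dim=1, **kwargs):
    """Return keyword arguments that main.py unpacks into Regression(**model_arch)."""
    # "encoder" as the expected keyword is an assumption about src.models.regression.Regression;
    # the signature mirrors the --model.arch.* flags plus input_shape/output_size/slice_dim
    encoder = nn.Sequential(
        nn.Flatten(),
        nn.Linear(math.prod(input_shape), lstm_latent_dim),
        nn.ReLU(),
        nn.Linear(lstm_latent_dim, output_size),
    )
    return {"encoder": encoder}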
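Resuming, similarly, only assumes that the statefile is a dict loadable with pickle_module=dill that contains an "epoch" entry; everything else in the checkpoint is up to lib.base_trainer.Trainer. A compatible writer might look like the sketch below, where the "model" and "optimizer" keys are assumptions rather than something main.py reads.

# hypothetical sketch of a checkpoint writer compatible with the resume check above
import dill
import torch


def save_checkpoint(path, model, optimizer, epoch):
    # "epoch" is the only key main.py reads when deciding whether training already finished;
    # the remaining entries are assumptions about what the Trainer stores
    torch.save(
        {"epoch": epoch,
         "model": model.state_dict(),
         "optimizer": optimizer.state_dict()},
        path,
        pickle_module=dill,
    )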