Showing 1 changed file with 276 additions and 0 deletions
main.py
0 → 100644
# -*- coding: utf-8 -*-
"""main.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/16Rn0IDJTE6vJMOZtafDAI19AvGD9eK4C
"""

import argparse
import importlib
import json
import logging
import os
import pprint
import sys

import dill
import torch
import wandb
from box import Box
from torch.utils.data import DataLoader

from src.common.dataset import get_dataset
from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils
from lib.base_trainer import Trainer
import easydict

def parser_setup():
    # define the run configuration (in place of argparse parsers)

    str2bool = os_utils.str2bool
    listorstr = os_utils.listorstr

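    # note: this EasyDict presumably mirrors the options that os_utils.parse_args
    # would otherwise populate from the command line (that call is commented out
    # in __main__ below); values here are hardcoded for the Colab run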
    parser = easydict.EasyDict({
        "debug": False,
        "config": None,
        "seed": 0,
        "wandb_use": False,
        "wandb_run_id": None,
        "wandb_watch": False,  # read as config.wandb_watch when wandb_use is set
        "project": "brain-age",
        "exp_name": None,
        "device": "cuda",
        "result_folder": "a",
        "mode": ["test", "train"],
        "statefile": None,

        "data": {
            "name": "brain_age",
            "root_path": "/content/drive/MyDrive/2d-slice-set-networks-for-brain-age-master/data",
            "train_csv": "data/train0609.csv",
            "valid_csv": "data/valid0609.csv",
            "test_csv": "data/test0609.csv",
            "feat_csv": None,
            "train_num_sample": 506,
            "frame_dim": 1,
            "frame_keep_style": "random",
            "frame_keep_fraction": 1,
            "impute": "drop",
        },

        "model": {
            "name": "regression",
            "arch": {
                "file": "src/arch/brain_age_3d.py",
                "lstm_feat_dim": 2,
                "lstm_latent_dim": 128,
                "attn_num_heads": 1,
                "attn_dim": 32,
                "attn_drop": False,
                "agg_fn": "attention"
            }
        },

        "train": {
            "batch_size": 8,

            "patience": 100,
            "max_epoch": 100,
            "optimizer": "adam",
            "lr": 1e-3,
            "weight_decay": 1e-4,
            "gradient_norm_clip": -1,

            "save_strategy": ["best", "last"],
            "log_every": 100,

            "stopping_criteria": "loss",
            "stopping_criteria_direction": "lower",
            "evaluations": None,

            "optimizer_momentum": None,

            "scheduler": None,
            "scheduler_gamma": None,
            "scheduler_milestones": None,
            "scheduler_patience": None,
            "scheduler_step_size": None,
            "scheduler_load_on_reduce": None,
        },

        "test": {
            "batch_size": 8,
            "evaluations": None,
            "eval_model": "best",
        },

        "_actions": None,
        "_defaults": None
    })

    print(parser.seed)
    return parser

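# usage sketch (assumption: no CLI flags are read, since the config above is
# hardcoded):
#   python main.py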
if __name__ == "__main__":
    # set seeds etc here
    torch.backends.cudnn.benchmark = True

    # define logger etc
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
    logger = logging.getLogger()

    parser = parser_setup()
    # config = os_utils.parse_args(parser)
    config = parser

    logger.info("Config:")
    logger.info(pprint.pformat(config, indent=4))

    os_utils.safe_makedirs(config.result_folder)
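    # get_state_params presumably resolves a previous checkpoint (statefile) and
    # wandb run id so that an interrupted run can resume into the same result folder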
    statefile, run_id, result_folder = os_utils.get_state_params(
        config.wandb_use, config.wandb_run_id, config.result_folder, config.statefile
    )
    config.statefile = statefile
    config.wandb_run_id = run_id
    config.result_folder = result_folder

    if statefile is not None:
        data = torch.load(open(statefile, "rb"), pickle_module=dill)
        epoch = data["epoch"]
        if epoch >= config.train.max_epoch:
            logger.error("Already trained up to max epoch; exiting")
            sys.exit()

    if config.wandb_use:
        wandb.init(
            name=config.exp_name if config.exp_name is not None else config.result_folder,
            config=dict(config),  # EasyDict is a dict subclass, so dict(config) gives a plain dict
            project=config.project,
            dir=config.result_folder,
            resume=config.wandb_run_id,
            id=config.wandb_run_id,
            sync_tensorboard=True,
        )
        logger.info(f"Starting wandb with id {config.wandb_run_id}")

    # NOTE: WANDB creates git patch so we probably can get rid of this in future
    os_utils.copy_code("src", config.result_folder, replace=True)
    json.dump(
        config,
        open(f"{wandb.run.dir if config.wandb_use else config.result_folder}/config.json", "w")
    )

    logger.info("Getting data and dataloaders")
    data, meta = get_dataset(**config.data, device=config.device, replace=True, frac=2000)

    # num_workers = max(min(os.cpu_count(), 8), 1)
    num_workers = os.cpu_count()
    logger.info(f"Using {num_workers} workers")
    train_loader = DataLoader(data["train"], shuffle=True, batch_size=config.train.batch_size,
                              num_workers=num_workers)
    valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
                             num_workers=num_workers)
    logger.info("Getting model")
    # load arch module
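    # the arch file path is converted to a module path for importlib, e.g.
    # "src/arch/brain_age_3d.py" -> "src.arch.brain_age_3d"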
    arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
    model_arch = arch_module.get_arch(
        input_shape=meta.get("input_shape"), output_size=meta.get("num_class"),
        **config.model.arch,
        slice_dim=config.data.frame_dim
    )

    # declaring models
    if config.model.name == "regression":
        from src.models.regression import Regression

        model = Regression(**model_arch)
    else:
        raise Exception("Unknown model")

    model.to(config.device)
    model.stats()
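    # model.stats() presumably logs a parameter-count / architecture summary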

    if config.wandb_use and config.wandb_watch:
        wandb.watch(model, log="all")

    # declaring optimizer, scheduler, and trainer
    optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
        model,
        lr=config.train.lr,
        optimizer=config.train.optimizer,
        opt_params={
            "weight_decay": config.train.get("weight_decay", 1e-4),
            "momentum": config.train.get("optimizer_momentum", 0.9)
        },
        scheduler=config.train.get("scheduler", None),
        scheduler_params={
            "gamma": config.train.get("scheduler_gamma", 0.1),
            "milestones": config.train.get("scheduler_milestones", [100, 200, 300]),
            "patience": config.train.get("scheduler_patience", 100),
            "step_size": config.train.get("scheduler_step_size", 100),
            "load_on_reduce": config.train.get("scheduler_load_on_reduce"),
            "mode": "max" if config.train.get(
                "stopping_criteria_direction") == "bigger" else "min"
        },
    )
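    # scheduler_params bundles keys for several scheduler types; presumably
    # get_optimizer_scheduler picks out the ones relevant to config.train.scheduler
    # (None here, so no scheduler should be created)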
    trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
                      result_dir=config.result_folder, log_every=config.train.log_every,
                      save_strategy=config.train.save_strategy,
                      patience=config.train.patience,
                      max_epoch=config.train.max_epoch,
                      stopping_criteria=config.train.stopping_criteria,
                      gradient_norm_clip=config.train.gradient_norm_clip,
                      stopping_criteria_direction=config.train.stopping_criteria_direction,
                      evaluations=Box({"train": config.train.evaluations,
                                       "test": config.test.evaluations}))

    if "train" in config.mode:
        logger.info("starting training")
        print(train_loader.dataset)
        trainer.train(train_loader, valid_loader)
        logger.info("Training done")

    # write test results one step past the current training step
    step_to_write = trainer.step
    step_to_write += 1

    if "test" in config.mode and config.test.eval_model == "best":
        if os.path.exists(f"{trainer.result_dir}/best_model.pt"):
            logger.info("Loading best model")
            trainer.load(f"{trainer.result_dir}/best_model.pt")
        else:
            logger.info("eval_model is best, but best model not found; evaluating last model")
    else:
        logger.info("eval_model is not best, so skipping loading at end of training")

    if "test" in config.mode:
        logger.info("evaluating model on test set")
        logger.info(f"Model was trained up to {trainer.epoch}")
        # write test results one step past the current training step
        step_to_write = trainer.step
        step_to_write += 1
        loss, aux_loss = trainer.test(train_loader, test_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="test",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, train_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="train_eval",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, valid_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="valid_eval",
                                         new_line=True)