Showing
1 changed file
with
0 additions
and
276 deletions
main.py
deleted
100644 → 0
1 | -# -*- coding: utf-8 -*- | ||
2 | -"""main.ipynb | ||
3 | - | ||
4 | -Automatically generated by Colaboratory. | ||
5 | - | ||
6 | -Original file is located at | ||
7 | - https://colab.research.google.com/drive/16Rn0IDJTE6vJMOZtafDAI19AvGD9eK4C | ||
8 | -""" | ||
9 | - | ||
10 | -import argparse | ||
11 | -import importlib | ||
12 | -import json | ||
13 | -import logging | ||
14 | -import os | ||
15 | -import pprint | ||
16 | -import sys | ||
17 | - | ||
18 | -import dill | ||
19 | -import torch | ||
20 | -import wandb | ||
21 | -from box import Box | ||
22 | -from torch.utils.data import DataLoader | ||
23 | - | ||
24 | -from src.common.dataset import get_dataset | ||
25 | -from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils | ||
26 | -from lib.base_trainer import Trainer | ||
27 | -import easydict | ||
28 | - | ||
def parser_setup():
  """Build the default experiment configuration.

  Despite the name, no argparse parser is created: the values below are
  hard-coded and returned as an ``easydict.EasyDict`` so that callers can use
  attribute access (``config.train.lr``) as well as dict access.

  Returns:
    easydict.EasyDict: top-level run options plus the nested ``data``,
    ``model``, ``train`` and ``test`` sections.
  """
  parser = easydict.EasyDict({
    "debug": False,
    "config": None,
    "seed": 0,
    "wandb_use": False,
    "wandb_run_id": None,
    # Kept for backward compatibility; a dotted key cannot be reached with
    # attribute syntax, so the same flag is also exposed as "wandb_watch".
    "wandb.watch": False,
    "wandb_watch": False,
    "project": "brain-age",
    "exp_name": None,
    "device": "cuda",
    "result_folder": "a",
    "mode": ["test", "train"],
    "statefile": None,

    "data": {
      "name": "brain_age",
      "root_path": "/content/drive/MyDrive/2d-slice-set-networks-for-brain-age-master/data",
      "train_csv": "data/train0609.csv",
      "valid_csv": "data/valid0609.csv",
      "test_csv": "data/test0609.csv",
      "feat_csv": None,
      "train_num_sample": 506,
      "frame_dim": 1,
      "frame_keep_style": "random",
      "frame_keep_fraction": 1,
      "impute": "drop",
    },

    "model": {
      "name": "regression",
      "arch": {
        "file": "src/arch/brain_age_3d.py",
        "lstm_feat_dim": 2,
        "lstm_latent_dim": 128,
        "attn_num_heads": 1,
        "attn_dim": 32,
        "attn_drop": False,
        "agg_fn": "attention",
      },
    },

    "train": {
      "batch_size": 8,

      "patience": 100,
      "max_epoch": 100,
      "optimizer": "adam",
      "lr": 1e-3,
      "weight_decay": 1e-4,
      "gradient_norm_clip": -1,

      "save_strategy": ["best", "last"],
      "log_every": 100,

      "stopping_criteria": "loss",
      "stopping_criteria_direction": "lower",
      "evaluations": None,

      "optimizer_momentum": None,

      "scheduler": None,
      "scheduler_gamma": None,
      "scheduler_milestones": None,
      "scheduler_patience": None,
      "scheduler_step_size": None,
      "scheduler_load_on_reduce": None,
    },

    "test": {
      "batch_size": 8,
      "evaluations": None,
      "eval_model": "best",
    },

    "_actions": None,
    "_defaults": None,
  })

  print(parser.seed)
  return parser
116 | - | ||
if __name__ == "__main__":
  # Entry point: build the config, optionally resume from a state file, then
  # train and/or evaluate the model according to config.mode.

  # set seeds etc here
  torch.backends.cudnn.benchmark = True

  # define logger etc
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
  logger = logging.getLogger()

  # Command-line parsing (os_utils.parse_args) is not used here; the defaults
  # returned by parser_setup() are the configuration.
  config = parser_setup()

  logger.info("Config:")
  logger.info(pprint.pformat(config, indent=4))

  os_utils.safe_makedirs(config.result_folder)
  statefile, run_id, result_folder = os_utils.get_state_params(
    config.wandb_use, config.wandb_run_id, config.result_folder, config.statefile
  )
  config.statefile = statefile
  config.wandb_run_id = run_id
  config.result_folder = result_folder

  if statefile is not None:
    # NOTE: the state file is unpickled (via dill); only load trusted files.
    with open(statefile, "rb") as state_fp:
      state = torch.load(state_fp, pickle_module=dill)
    epoch = state["epoch"]
    # Only the epoch is needed here; drop the checkpoint before building data.
    del state
    if epoch >= config.train.max_epoch:
      logger.error("Aleady trained upto max epoch; exiting")
      sys.exit()

  wandb_run = None
  if config.wandb_use:
    # Keep the run handle: its directory is needed below for config.json.
    wandb_run = wandb.init(
      name=config.exp_name if config.exp_name is not None else config.result_folder,
      # The config is a dict subclass without a to_dict() method; a JSON
      # round-trip yields plain nested dicts (it is dumped as JSON below anyway).
      config=json.loads(json.dumps(config)),
      project=config.project,
      dir=config.result_folder,
      resume=config.wandb_run_id,
      id=config.wandb_run_id,
      sync_tensorboard=True,
    )
    logger.info(f"Starting wandb with id {config.wandb_run_id}")

  # NOTE: WANDB creates git patch so we probably can get rid of this in future
  os_utils.copy_code("src", config.result_folder, replace=True,)
  config_dir = wandb_run.dir if wandb_run is not None else config.result_folder
  with open(f"{config_dir}/config.json", "w") as config_fp:
    json.dump(config, config_fp)

  logger.info("Getting data and dataloaders")
  data, meta = get_dataset(**config.data, device=config.device, replace=True, frac=2000)

  # os.cpu_count() may return None; fall back to loading in the main process.
  num_workers = os.cpu_count() or 0
  logger.info(f"Using {num_workers} workers")
  train_loader = DataLoader(data["train"], shuffle=True, batch_size=config.train.batch_size,
                            num_workers=num_workers)
  valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
                            num_workers=num_workers)
  test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
                           num_workers=num_workers)

  logger.info("Getting model")
  # load arch module: "src/arch/x.py" -> "src.arch.x"
  arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
  model_arch = arch_module.get_arch(
    input_shape=meta.get("input_shape"), output_size=meta.get("num_class"),
    **config.model.arch,
    slice_dim=config.data.frame_dim
  )

  # declaring models
  # Exact match: a substring test would also accept "" or "reg".
  if config.model.name == "regression":
    from src.models.regression import Regression

    model = Regression(**model_arch)
  else:
    raise Exception("Unknown model")

  model.to(config.device)
  model.stats()

  # The flag may be stored as "wandb_watch" or under the dotted key
  # "wandb.watch"; accept either spelling.
  watch_model = config.get("wandb_watch", config.get("wandb.watch", False))
  if config.wandb_use and watch_model:
    wandb.watch(model, log="all")

  # declaring trainer
  # NOTE(review): the .get() fallbacks below apply only when a key is absent;
  # the default config defines these keys (mostly as None), so None is what
  # gets passed through. Confirm get_optimizer_scheduler handles None.
  optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
    model,
    lr=config.train.lr,
    optimizer=config.train.optimizer,
    opt_params={
      "weight_decay": config.train.get("weight_decay", 1e-4),
      "momentum"    : config.train.get("optimizer_momentum", 0.9)
    },
    scheduler=config.train.get("scheduler", None),
    scheduler_params={
      "gamma"         : config.train.get("scheduler_gamma", 0.1),
      "milestones"    : config.train.get("scheduler_milestones", [100, 200, 300]),
      "patience"      : config.train.get("scheduler_patience", 100),
      "step_size"     : config.train.get("scheduler_step_size", 100),
      "load_on_reduce": config.train.get("scheduler_load_on_reduce"),
      "mode"          : "max" if config.train.get(
        "stopping_criteria_direction") == "bigger" else "min"
    },
  )
  trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
                    result_dir=config.result_folder, log_every=config.train.log_every,
                    save_strategy=config.train.save_strategy,
                    patience=config.train.patience,
                    max_epoch=config.train.max_epoch,
                    stopping_criteria=config.train.stopping_criteria,
                    gradient_norm_clip=config.train.gradient_norm_clip,
                    stopping_criteria_direction=config.train.stopping_criteria_direction,
                    evaluations=Box({"train": config.train.evaluations,
                                     "test" : config.test.evaluations}))

  if "train" in config.mode:
    logger.info("starting training")
    print(train_loader.dataset)
    trainer.train(train_loader, valid_loader)
    logger.info("Training done;")

    if "test" in config.mode and config.test.eval_model == "best":
      best_model_path = f"{trainer.result_dir}/best_model.pt"
      if os.path.exists(best_model_path):
        logger.info("Loading best model")
        trainer.load(best_model_path)
      else:
        logger.info("eval_model is best, but best model not found ::: evaling last model")
    else:
      logger.info("eval model is not best, so skipping loading at end of training")

  if "test" in config.mode:
    logger.info("evaluating model on test set")
    logger.info(f"Model was trained upto {trainer.epoch}")
    # Results are written one step past the trainer's current step.
    # NOTE(review): the step is read after the optional best-checkpoint reload
    # above, exactly as the original did; confirm whether the pre-reload step
    # was the intended one.
    step_to_write = trainer.step + 1

    # Evaluate on the test, train and validation splits, in that order.
    eval_splits = (
      ("test", test_loader),
      ("train_eval", train_loader),
      ("valid_eval", valid_loader),
    )
    for split_tag, eval_loader in eval_splits:
      loss, aux_loss = trainer.test(train_loader, eval_loader)
      logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                       force_print=True, step=step_to_write,
                                       epoch=trainer.epoch,
                                       log_every=trainer.log_every, string=split_tag,
                                       new_line=True)
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment