# -*- coding: utf-8 -*-
"""main.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1X0vDvK01A8_3JbSAu1whlPD-rBx0JlM-
"""

import argparse
import importlib
import json
import logging
import os
import pprint
import sys

import dill
import easydict
import torch
import wandb
from box import Box
from torch.utils.data import DataLoader

from src.common.dataset import get_dataset
from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils
from lib.base_trainer import Trainer


def parser_setup():
    # Define the run configuration. An EasyDict stands in for the original argparse
    # parser, so every entry below acts as a default option.

    # converters from os_utils; unused here since CLI parsing is disabled
    str2bool = os_utils.str2bool
    listorstr = os_utils.listorstr

    parser = easydict.EasyDict({
        "debug": False,
        "config": None,
        "seed": 0,
        "wandb_use": False,
        "wandb_run_id": None,
        "wandb_watch": False,
        "project": "brain-age",
        "exp_name": None,
        "device": "cuda",
        "result_folder": "a",
        "mode": ["test", "train"],
        "statefile": None,

        "data": {
            "name": "brain_age",
            "root_path": "**root path",
            "train_csv": "**train csv",
            "valid_csv": "**valid csv",
            "test_csv": "**test csv",
            "feat_csv": None,
            "train_num_sample": -1,
            "frame_dim": 1,
            "frame_keep_style": "random",
            "frame_keep_fraction": 1,
            "impute": "drop",
        },

        "model": {
            "name": "regression",
            "arch": {
                "file": "src/arch/brain_age_3d.py",
                "lstm_feat_dim": 2,
                "lstm_latent_dim": 128,
                "attn_num_heads": 1,
                "attn_dim": 32,
                "attn_drop": False,
                "agg_fn": "attention",
            },
        },

        "train": {
            "batch_size": 8,

            "patience": 100,
            "max_epoch": 100,
            "optimizer": "adam",
            "lr": 1e-3,
            "weight_decay": 1e-4,
            "gradient_norm_clip": -1,

            "save_strategy": ["best", "last"],
            "log_every": 100,

            "stopping_criteria": "loss",
            "stopping_criteria_direction": "lower",
            "evaluations": None,

            "optimizer_momentum": None,

            "scheduler": None,
            "scheduler_gamma": None,
            "scheduler_milestones": None,
            "scheduler_patience": None,
            "scheduler_step_size": None,
            "scheduler_load_on_reduce": None,
        },

        "test": {
            "batch_size": 8,
            "evaluations": None,
            "eval_model": "best",
        },

        "_actions": None,
        "_defaults": None,
    })

    print(parser.seed)
    return parser


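# Example (hypothetical values): the defaults above can be overridden before the main
# block runs, e.g. in a notebook cell or a small driver script:
#   config = parser_setup()
#   config.data.root_path = "/path/to/scans"   # placeholder path
#   config.train.max_epoch = 50
#   config.wandb_use = True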
if __name__ == "__main__":
    # set seeds etc. here; cuDNN benchmark autotuning is enabled for fixed-size inputs
    torch.backends.cudnn.benchmark = True

    # define logger etc
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
    logger = logging.getLogger()

    parser = parser_setup()
    # CLI parsing is disabled; the EasyDict defaults from parser_setup() are used directly.
    # config = os_utils.parse_args(parser)
    config = parser

    logger.info("Config:")
    logger.info(pprint.pformat(config, indent=4))

    os_utils.safe_makedirs(config.result_folder)
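    # Resolve state for resuming: reuse an existing checkpoint, wandb run id and result
    # folder if a previous run is found, so the run continues instead of restarting.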
    statefile, run_id, result_folder = os_utils.get_state_params(
        config.wandb_use, config.wandb_run_id, config.result_folder, config.statefile
    )
    config.statefile = statefile
    config.wandb_run_id = run_id
    config.result_folder = result_folder

    if statefile is not None:
        data = torch.load(open(statefile, "rb"), pickle_module=dill)
        epoch = data["epoch"]
        if epoch >= config.train.max_epoch:
            logger.error("Already trained up to max epoch; exiting")
            sys.exit()

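    # Optional Weights & Biases logging; the run is resumed under the stored run id so
    # metrics from an interrupted job continue in the same dashboard entry.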
    if config.wandb_use:
        wandb.init(
            name=config.exp_name if config.exp_name is not None else config.result_folder,
            config=dict(config),
            project=config.project,
            dir=config.result_folder,
            resume=config.wandb_run_id,
            id=config.wandb_run_id,
            sync_tensorboard=True,
        )
        logger.info(f"Starting wandb with id {config.wandb_run_id}")

    # NOTE: wandb creates a git patch, so we can probably get rid of this in the future
    os_utils.copy_code("src", config.result_folder, replace=True)
    json.dump(
        config,
        open(f"{wandb.run.dir if config.wandb_use else config.result_folder}/config.json", "w"),
    )

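    # Build dataset splits and wrap them in DataLoaders. Note that shuffle=False is used
    # for all three splits, so training batches are drawn in a fixed order.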
    logger.info("Getting data and dataloaders")
    data, meta = get_dataset(**config.data, device=config.device, replace=True, frac=2000)

    # num_workers = max(min(os.cpu_count(), 8), 1)
    num_workers = os.cpu_count()
    logger.info(f"Using {num_workers} workers")
    train_loader = DataLoader(data["train"], shuffle=False, batch_size=config.train.batch_size,
                              num_workers=num_workers)
    valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
                             num_workers=num_workers)

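    # The architecture file path (e.g. src/arch/brain_age_3d.py) is converted to a module
    # path and imported dynamically; its get_arch() returns the keyword arguments used to
    # construct the model below.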
    logger.info("Getting model")

    # load arch module
    arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
    model_arch = arch_module.get_arch(
        input_shape=meta.get("input_shape"), output_size=meta.get("num_class"),
        **config.model.arch,
        slice_dim=config.data.frame_dim
    )

    # declaring models
    if config.model.name == "regression":
        from src.models.regression import Regression

        model = Regression(**model_arch)
    else:
        raise Exception("Unknown model")

    model.to(config.device)
    model.stats()

    if config.wandb_use and config.wandb_watch:
        wandb.watch(model, log="all")

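    # Optimizer, LR scheduler and Trainer are all built from config.train; checkpointing,
    # early stopping and gradient clipping are controlled by the values set in parser_setup().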
    # declaring trainer
    optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
        model,
        lr=config.train.lr,
        optimizer=config.train.optimizer,
        opt_params={
            "weight_decay": config.train.get("weight_decay", 1e-4),
            "momentum": config.train.get("optimizer_momentum", 0.9),
        },
        scheduler=config.train.get("scheduler", None),
        scheduler_params={
            "gamma": config.train.get("scheduler_gamma", 0.1),
            "milestones": config.train.get("scheduler_milestones", [100, 200, 300]),
            "patience": config.train.get("scheduler_patience", 100),
            "step_size": config.train.get("scheduler_step_size", 100),
            "load_on_reduce": config.train.get("scheduler_load_on_reduce"),
            "mode": "max" if config.train.get("stopping_criteria_direction") == "bigger" else "min",
        },
    )
    trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
                      result_dir=config.result_folder, log_every=config.train.log_every,
                      save_strategy=config.train.save_strategy,
                      patience=config.train.patience,
                      max_epoch=config.train.max_epoch,
                      stopping_criteria=config.train.stopping_criteria,
                      gradient_norm_clip=config.train.gradient_norm_clip,
                      stopping_criteria_direction=config.train.stopping_criteria_direction,
                      evaluations=Box({"train": config.train.evaluations,
                                       "test": config.test.evaluations}))

    if "train" in config.mode:
        logger.info("starting training")
        print(train_loader.dataset)
        trainer.train(train_loader, valid_loader)
        logger.info("Training done;")

    # copy the current step and advance it by one; test results are written at this step
    step_to_write = trainer.step
    step_to_write += 1

    if "test" in config.mode and config.test.eval_model == "best":
        if os.path.exists(f"{trainer.result_dir}/best_model.pt"):
            logger.info("Loading best model")
            trainer.load(f"{trainer.result_dir}/best_model.pt")
        else:
            logger.info("eval_model is best, but best model not found; evaluating the last model instead")
    else:
        logger.info("eval_model is not best, so skipping model loading at end of training")

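    # Final evaluation pass: run trainer.test on the test, train and validation loaders and
    # log each result at step_to_write (one step past the last training step).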
    if "test" in config.mode:
        logger.info("evaluating model on test set")
        logger.info(f"Model was trained up to {trainer.epoch}")
        # copy the current step and advance it by one; evaluation results are written at this step
        step_to_write = trainer.step
        step_to_write += 1

        print("<<<<<test>>>>>")
        loss, aux_loss = trainer.test(train_loader, test_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="test",
                                         new_line=True)

        print("<<<<<training>>>>>")
        loss, aux_loss = trainer.test(train_loader, train_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="train_eval",
                                         new_line=True)

        print("<<<<<validation>>>>>")
        loss, aux_loss = trainer.test(train_loader, valid_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="valid_eval",
                                         new_line=True)