# -*- coding: utf-8 -*-
"""main.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/16Rn0IDJTE6vJMOZtafDAI19AvGD9eK4C
"""

import importlib
import json
import logging
import os
import pprint
import sys

import dill
import easydict
import torch
import wandb
from box import Box
from torch.utils.data import DataLoader

from src.common.dataset import get_dataset
from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils
from lib.base_trainer import Trainer

def parser_setup():
    # build the run configuration; the original argparse CLI is replaced here
    # by a hard-coded EasyDict of defaults
    parser = easydict.EasyDict({
37 + "debug":False,
38 + "config":None,
39 + "seed":0,
40 + "wandb_use":False,
41 + "wandb_run_id":None,
42 + "wandb.watch":False,
43 + "project":"brain-age",
44 + "exp_name":None,
45 + "device":"cuda",
46 + "result_folder":"a",
47 + "mode":["test", "train"],
48 + "statefile":None,
49 +
50 + "data" : {
51 + "name":"brain_age",
52 + "root_path":"/content/drive/MyDrive/2d-slice-set-networks-for-brain-age-master/data",
53 + "train_csv":"data/train0609.csv",
54 + "valid_csv":"data/valid0609.csv",
55 + "test_csv":"data/test0609.csv",
56 + "feat_csv":None,
57 + "train_num_sample":506,
58 + "frame_dim":1,
59 + "frame_keep_style":"random",
60 + "frame_keep_fraction":1,
61 + "impute":"drop",
62 + },
63 +
64 + "model" : {
65 + "name":"regression",
66 + "arch": {
67 + "file":"src/arch/brain_age_3d.py",
68 + "lstm_feat_dim":2,
69 + "lstm_latent_dim":128,
70 + "attn_num_heads":1,
71 + "attn_dim":32,
72 + "attn_drop":False,
73 + "agg_fn":"attention"
74 + }
75 + },
76 +
77 + "train":{
78 + "batch_size":8,
79 +
80 + "patience":100,
81 + "max_epoch":100,
82 + "optimizer":"adam",
83 + "lr":1e-3,
84 + "weight_decay":1e-4,
85 + "gradient_norm_clip":-1,
86 +
87 + "save_strategy":["best", "last"],
88 + "log_every":100,
89 +
90 + "stopping_criteria":"loss",
91 + "stopping_criteria_direction":"lower",
92 + "evaluations":None,
93 +
94 + "optimizer_momentum":None,
95 +
96 + "scheduler":None,
97 + "scheduler_gamma":None,
98 + "scheduler_milestones":None,
99 + "scheduler_patience":None,
100 + "scheduler_step_size":None,
101 + "scheduler_load_on_reduce":None,
102 + },
103 +
104 + "test":{
105 + "batch_size":8,
106 + "evaluations":None,
107 + "eval_model":"best",
108 + },
109 +
110 + "_actions":None,
111 + "_defaults":None
112 + })

    return parser
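
# A minimal sketch (hypothetical usage, not in the original script): since
# parser_setup() returns a plain EasyDict rather than an argparse parser,
# settings can be overridden in code before a run, e.g. for a CPU smoke test:
#
#   cfg = parser_setup()
#   cfg.device = "cpu"        # assumption: no GPU available
#   cfg.train.max_epoch = 1   # one epoch, just to exercise the pipeline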

if __name__ == "__main__":
    # speed: let cuDNN pick the fastest kernels for fixed-size inputs
    torch.backends.cudnn.benchmark = True

    # define logger etc
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
    logger = logging.getLogger()

    parser = parser_setup()
    # config = os_utils.parse_args(parser)
    config = parser

    # set seeds etc here
    torch.manual_seed(config.seed)

    logger.info("Config:")
    logger.info(pprint.pformat(config, indent=4))

    os_utils.safe_makedirs(config.result_folder)
    statefile, run_id, result_folder = os_utils.get_state_params(
        config.wandb_use, config.wandb_run_id, config.result_folder, config.statefile
    )
    config.statefile = statefile
    config.wandb_run_id = run_id
    config.result_folder = result_folder

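    # resume bookkeeping: if a previous statefile exists, check how far it got
    # and exit early when training already reached max_epoch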
    if statefile is not None:
        data = torch.load(open(statefile, "rb"), pickle_module=dill)
        epoch = data["epoch"]
        if epoch >= config.train.max_epoch:
            logger.error("Already trained up to max epoch; exiting")
            sys.exit()

    if config.wandb_use:
        wandb.init(
            name=config.exp_name if config.exp_name is not None else config.result_folder,
            config=dict(config),
            project=config.project,
            dir=config.result_folder,
            resume=config.wandb_run_id,
            id=config.wandb_run_id,
            sync_tensorboard=True,
        )
        logger.info(f"Starting wandb with id {config.wandb_run_id}")

    # NOTE: wandb creates a git patch, so we can probably drop this copy in the future
    os_utils.copy_code("src", config.result_folder, replace=True)
    json.dump(
        config,
        open(f"{wandb.run.dir if config.wandb_use else config.result_folder}/config.json", "w")
    )

    logger.info("Getting data and dataloaders")
    data, meta = get_dataset(**config.data, device=config.device, replace=True, frac=2000)

    # num_workers = max(min(os.cpu_count(), 8), 1)
    num_workers = os.cpu_count()
    logger.info(f"Using {num_workers} workers")
    train_loader = DataLoader(data["train"], shuffle=True, batch_size=config.train.batch_size,
                              num_workers=num_workers)
    valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
                             num_workers=num_workers)
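    # note: os.cpu_count() workers for each of the three loaders can
    # oversubscribe the machine; the capped variant commented out above is the
    # safer choice on shared hosts
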
178 + logger.info("Getting model")
179 + # load arch module
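    # (the file path "src/arch/brain_age_3d.py" becomes the module path
    # "src.arch.brain_age_3d" for importlib)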
    arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
    model_arch = arch_module.get_arch(
        input_shape=meta.get("input_shape"), output_size=meta.get("num_class"),
        **config.model.arch,
        slice_dim=config.data.frame_dim,
    )

    # declaring models
    if config.model.name == "regression":
        from src.models.regression import Regression

        model = Regression(**model_arch)
    else:
        raise ValueError(f"Unknown model: {config.model.name}")

    model.to(config.device)
    model.stats()

    if config.wandb_use and config.wandb_watch:
        wandb.watch(model, log="all")

    # declaring trainer
    optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
        model,
        lr=config.train.lr,
        optimizer=config.train.optimizer,
        opt_params={
            "weight_decay": config.train.get("weight_decay", 1e-4),
            "momentum": config.train.get("optimizer_momentum", 0.9),
        },
        scheduler=config.train.get("scheduler", None),
        scheduler_params={
            "gamma": config.train.get("scheduler_gamma", 0.1),
            "milestones": config.train.get("scheduler_milestones", [100, 200, 300]),
            "patience": config.train.get("scheduler_patience", 100),
            "step_size": config.train.get("scheduler_step_size", 100),
            "load_on_reduce": config.train.get("scheduler_load_on_reduce"),
            "mode": "max" if config.train.get(
                "stopping_criteria_direction") == "bigger" else "min",
        },
    )
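    # the scheduler "mode" above follows the stopping criterion: "bigger" means
    # the tracked metric should improve upward, anything else is treated as a
    # loss to minimize
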
    trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
                      result_dir=config.result_folder, log_every=config.train.log_every,
                      save_strategy=config.train.save_strategy,
                      patience=config.train.patience,
                      max_epoch=config.train.max_epoch,
                      stopping_criteria=config.train.stopping_criteria,
                      gradient_norm_clip=config.train.gradient_norm_clip,
                      stopping_criteria_direction=config.train.stopping_criteria_direction,
                      evaluations=Box({"train": config.train.evaluations,
                                       "test": config.test.evaluations}))

    if "train" in config.mode:
        logger.info("Starting training")
        trainer.train(train_loader, valid_loader)
        logger.info("Training done")

        # write test results one step past the last training step
        step_to_write = trainer.step
        step_to_write += 1

    if "test" in config.mode and config.test.eval_model == "best":
        if os.path.exists(f"{trainer.result_dir}/best_model.pt"):
            logger.info("Loading best model")
            trainer.load(f"{trainer.result_dir}/best_model.pt")
        else:
            logger.info("eval_model is best, but no best model was found; evaluating the last model instead")
    else:
        logger.info("eval_model is not best; skipping model loading at the end of training")

    if "test" in config.mode:
        logger.info("evaluating model on test set")
        logger.info(f"Model was trained up to epoch {trainer.epoch}")
        # write test results one step past the last training step
        step_to_write = trainer.step
        step_to_write += 1
        loss, aux_loss = trainer.test(train_loader, test_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="test",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, train_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="train_eval",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, valid_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="valid_eval",
                                         new_line=True)
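
# A possible follow-up (sketch, not part of the original script): the dumped
# config.json can be reloaded to reproduce a run with the same settings,
# assuming the result_folder layout used above:
#
#   with open(f"{config.result_folder}/config.json") as f:
#       config = easydict.EasyDict(json.load(f))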