
Delete main.py

Showing 1 changed file with 0 additions and 276 deletions
# -*- coding: utf-8 -*-
"""main.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/16Rn0IDJTE6vJMOZtafDAI19AvGD9eK4C
"""

import argparse
import importlib
import json
import logging
import os
import pprint
import sys

import dill
import torch
import wandb
from box import Box
from torch.utils.data import DataLoader

from src.common.dataset import get_dataset
from lib.utils import logging as logging_utils, os as os_utils, optimizer as optimizer_utils
from lib.base_trainer import Trainer
import easydict

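# The run configuration is defined inline as an EasyDict rather than parsed from the
# command line; the original argparse-based parsing is left commented out in the
# __main__ block below.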
def parser_setup():
    # define the configuration

    # helpers from the argparse-based setup; unused with the inline config
    str2bool = os_utils.str2bool
    listorstr = os_utils.listorstr

    parser = easydict.EasyDict({
        "debug": False,
        "config": None,
        "seed": 0,
        "wandb_use": False,
        "wandb_run_id": None,
        "wandb_watch": False,
        "project": "brain-age",
        "exp_name": None,
        "device": "cuda",
        "result_folder": "a",
        "mode": ["test", "train"],
        "statefile": None,

        "data": {
            "name": "brain_age",
            "root_path": "/content/drive/MyDrive/2d-slice-set-networks-for-brain-age-master/data",
            "train_csv": "data/train0609.csv",
            "valid_csv": "data/valid0609.csv",
            "test_csv": "data/test0609.csv",
            "feat_csv": None,
            "train_num_sample": 506,
            "frame_dim": 1,
            "frame_keep_style": "random",
            "frame_keep_fraction": 1,
            "impute": "drop",
        },

        "model": {
            "name": "regression",
            "arch": {
                "file": "src/arch/brain_age_3d.py",
                "lstm_feat_dim": 2,
                "lstm_latent_dim": 128,
                "attn_num_heads": 1,
                "attn_dim": 32,
                "attn_drop": False,
                "agg_fn": "attention",
            },
        },

        "train": {
            "batch_size": 8,

            "patience": 100,
            "max_epoch": 100,
            "optimizer": "adam",
            "lr": 1e-3,
            "weight_decay": 1e-4,
            "gradient_norm_clip": -1,

            "save_strategy": ["best", "last"],
            "log_every": 100,

            "stopping_criteria": "loss",
            "stopping_criteria_direction": "lower",
            "evaluations": None,

            "optimizer_momentum": None,

            "scheduler": None,
            "scheduler_gamma": None,
            "scheduler_milestones": None,
            "scheduler_patience": None,
            "scheduler_step_size": None,
            "scheduler_load_on_reduce": None,
        },

        "test": {
            "batch_size": 8,
            "evaluations": None,
            "eval_model": "best",
        },

        "_actions": None,
        "_defaults": None,
    })

    print(parser.seed)
    return parser

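# Entry point: build the config, prepare the data loaders, construct the model and
# trainer, then run training and/or evaluation depending on config.mode.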
if __name__ == "__main__":
    # set seeds etc here
    torch.backends.cudnn.benchmark = True

    # define logger etc
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
    logger = logging.getLogger()

    parser = parser_setup()
    # config = os_utils.parse_args(parser)
    config = parser

    logger.info("Config:")
    logger.info(pprint.pformat(config, indent=4))

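    # resolve the state file, wandb run id, and result folder so a previous run can be resumed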
    os_utils.safe_makedirs(config.result_folder)
    statefile, run_id, result_folder = os_utils.get_state_params(
        config.wandb_use, config.wandb_run_id, config.result_folder, config.statefile
    )
    config.statefile = statefile
    config.wandb_run_id = run_id
    config.result_folder = result_folder

    if statefile is not None:
        data = torch.load(open(statefile, "rb"), pickle_module=dill)
        epoch = data["epoch"]
        if epoch >= config.train.max_epoch:
            logger.error("Already trained up to max epoch; exiting")
            sys.exit()

    if config.wandb_use:
        wandb.init(
            name=config.exp_name if config.exp_name is not None else config.result_folder,
            config=dict(config),
            project=config.project,
            dir=config.result_folder,
            resume=config.wandb_run_id,
            id=config.wandb_run_id,
            sync_tensorboard=True,
        )
        logger.info(f"Starting wandb with id {config.wandb_run_id}")

    # NOTE: wandb creates a git patch, so we can probably get rid of this in the future
    os_utils.copy_code("src", config.result_folder, replace=True)
    json.dump(
        config,
        open(f"{wandb.run.dir if config.wandb_use else config.result_folder}/config.json", "w")
    )

    logger.info("Getting data and dataloaders")
    data, meta = get_dataset(**config.data, device=config.device, replace=True, frac=2000)

    # num_workers = max(min(os.cpu_count(), 8), 1)
    num_workers = os.cpu_count()
    logger.info(f"Using {num_workers} workers")
    train_loader = DataLoader(data["train"], shuffle=True, batch_size=config.train.batch_size,
                              num_workers=num_workers)
    valid_loader = DataLoader(data["valid"], shuffle=False, batch_size=config.test.batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(data["test"], shuffle=False, batch_size=config.test.batch_size,
                             num_workers=num_workers)
    logger.info("Getting model")
    # load arch module
    arch_module = importlib.import_module(config.model.arch.file.replace("/", ".")[:-3])
    model_arch = arch_module.get_arch(
        input_shape=meta.get("input_shape"), output_size=meta.get("num_class"),
        **config.model.arch,
        slice_dim=config.data.frame_dim
    )

    # declaring models
    if config.model.name == "regression":
        from src.models.regression import Regression

        model = Regression(**model_arch)
    else:
        raise Exception(f"Unknown model: {config.model.name}")
    model.to(config.device)
    model.stats()

    if config.wandb_use and config.wandb_watch:
        wandb.watch(model, log="all")

    # declaring trainer
    optimizer, scheduler = optimizer_utils.get_optimizer_scheduler(
        model,
        lr=config.train.lr,
        optimizer=config.train.optimizer,
        opt_params={
            "weight_decay": config.train.get("weight_decay", 1e-4),
            "momentum": config.train.get("optimizer_momentum", 0.9),
        },
        scheduler=config.train.get("scheduler", None),
        scheduler_params={
            "gamma": config.train.get("scheduler_gamma", 0.1),
            "milestones": config.train.get("scheduler_milestones", [100, 200, 300]),
            "patience": config.train.get("scheduler_patience", 100),
            "step_size": config.train.get("scheduler_step_size", 100),
            "load_on_reduce": config.train.get("scheduler_load_on_reduce"),
            "mode": "max" if config.train.get(
                "stopping_criteria_direction") == "bigger" else "min",
        },
    )
    trainer = Trainer(model, optimizer, scheduler=scheduler, statefile=config.statefile,
                      result_dir=config.result_folder, log_every=config.train.log_every,
                      save_strategy=config.train.save_strategy,
                      patience=config.train.patience,
                      max_epoch=config.train.max_epoch,
                      stopping_criteria=config.train.stopping_criteria,
                      gradient_norm_clip=config.train.gradient_norm_clip,
                      stopping_criteria_direction=config.train.stopping_criteria_direction,
                      evaluations=Box({"train": config.train.evaluations,
                                       "test": config.test.evaluations}))

    if "train" in config.mode:
        logger.info("Starting training")
        print(train_loader.dataset)
        trainer.train(train_loader, valid_loader)
        logger.info("Training done")

        # move past the last training step so test results are logged at a fresh step
        step_to_write = trainer.step
        step_to_write += 1

        if "test" in config.mode and config.test.eval_model == "best":
            if os.path.exists(f"{trainer.result_dir}/best_model.pt"):
                logger.info("Loading best model")
                trainer.load(f"{trainer.result_dir}/best_model.pt")
            else:
                logger.info("eval_model is 'best', but no best model was found; evaluating the last model")
        else:
            logger.info("eval_model is not 'best'; skipping model loading at the end of training")

    if "test" in config.mode:
        logger.info("Evaluating model on the test set")
        logger.info(f"Model was trained up to epoch {trainer.epoch}")
        # move past the current step so evaluation results are logged at a fresh step
        step_to_write = trainer.step
        step_to_write += 1
        loss, aux_loss = trainer.test(train_loader, test_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="test",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, train_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="train_eval",
                                         new_line=True)

        loss, aux_loss = trainer.test(train_loader, valid_loader)
        logging_utils.loss_logger_helper(loss, aux_loss, writer=trainer.summary_writer,
                                         force_print=True, step=step_to_write,
                                         epoch=trainer.epoch,
                                         log_every=trainer.log_every, string="valid_eval",
                                         new_line=True)