Hyunji

base trainer code

1 +"""trainer code"""
2 +import copy
3 +import logging
4 +import os
5 +from typing import List, Dict, Optional, Callable, Union
6 +
7 +import dill
8 +import numpy as np
9 +import torch
10 +from torch.utils.tensorboard import SummaryWriter
11 +
12 +from lib.utils.logging import loss_logger_helper
13 +
14 +logger = logging.getLogger()
15 +
16 +
class Trainer:
    # This is like skorch, but instead of callbacks we use class methods (looks less magical).
    # This is an evolving template.
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        result_dir: Optional[str],
        statefile: Optional[str] = None,
        log_every: int = 100,
        save_strategy: Optional[List] = None,
        patience: int = 20,
        max_epoch: int = 100,
        gradient_norm_clip: float = -1,
        stopping_criteria_direction: str = "bigger",
        stopping_criteria: Optional[Union[str, Callable]] = "accuracy",
        evaluations=None,
        **kwargs,
    ):
        """
        stopping_criteria: can be a function, a string, or None. If a string, it should be
        "loss" or match one of the keys in aux_loss; if None, early stopping is not invoked.
        """
        super().__init__()

        self.result_dir = result_dir
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.evaluations = evaluations
        self.gradient_norm_clip = gradient_norm_clip

        # training state related params
        self.epoch = 0
        self.step = 0
        self.best_criteria = None
        self.best_epoch = -1

        # config related params
        self.log_every = log_every
        # default to an empty list so that membership checks in the hooks don't fail on None
        self.save_strategy = save_strategy if save_strategy is not None else []
        self.patience = patience
        self.max_epoch = max_epoch
        self.stopping_criteria_direction = stopping_criteria_direction
        self.stopping_criteria = stopping_criteria

        # TODO: should save config and see if things have changed?
        if statefile is not None:
            self.load(statefile)

        # init best model
        self.best_model = self.model.state_dict()

        # logging stuff
        if result_dir is not None:
            # we do not need to purge; purging can delete the validation result
            self.summary_writer = SummaryWriter(log_dir=result_dir)

    def load(self, fname: str) -> Dict:
        """
        fname: file name to load data from
        """

        data = torch.load(open(fname, "rb"), pickle_module=dill, map_location=self.model.device)

        if getattr(self, "model", None) and data.get("model") is not None:
            state_dict = self.model.state_dict()
            state_dict.update(data["model"])
            self.model.load_state_dict(state_dict)

        if getattr(self, "optimizer", None) and data.get("optimizer") is not None:
            optimizer_dict = self.optimizer.state_dict()
            optimizer_dict.update(data["optimizer"])
            self.optimizer.load_state_dict(optimizer_dict)

        if getattr(self, "scheduler", None) and data.get("scheduler") is not None:
            scheduler_dict = self.scheduler.state_dict()
            scheduler_dict.update(data["scheduler"])
            self.scheduler.load_state_dict(scheduler_dict)

        self.epoch = data["epoch"]
        self.step = data["step"]
        self.best_criteria = data["best_criteria"]
        self.best_epoch = data["best_epoch"]
        return data

    def save(self, fname: str, **kwargs):
        """
        fname: file name to save to
        kwargs: additional things we may want to save

        By default we save:
        - model,
        - optimizer,
        - scheduler (if present),
        - epoch,
        - step,
        - best_criteria,
        - best_epoch
        """
        # NOTE: the best model is kept in memory, but it is written out automatically depending
        # on the save strategy, so that it can be loaded outside of the training process.
        kwargs.update({
            "model"        : self.model.state_dict(),
            "optimizer"    : self.optimizer.state_dict(),
            "epoch"        : self.epoch,
            "step"         : self.step,
            "best_criteria": self.best_criteria,
            "best_epoch"   : self.best_epoch,
        })

        if self.scheduler is not None:
            kwargs.update({"scheduler": self.scheduler.state_dict()})

        torch.save(kwargs, open(fname, "wb"), pickle_module=dill)

    # TODO: allow extracting predictions
    def run_iteration(self, batch, training: bool = True, reduce: bool = True):
        """
        batch: batch of data, passed to the model as is
        training: True to backprop and update weights, False for evaluation only
        reduce: whether to reduce the loss to its mean or return the raw per-sample vector
        """
        pred = self.model(batch)
        loss, aux_loss = self.model.loss(pred, batch, reduce=reduce)

        if training:
            loss.backward()
            if self.gradient_norm_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_norm_clip)
            self.optimizer.step()
            self.optimizer.zero_grad()

        return loss, aux_loss

    def compute_criteria(self, loss, aux_loss):
        stopping_criteria = self.stopping_criteria
        if stopping_criteria is None:
            return loss

        if callable(stopping_criteria):
            return stopping_criteria(loss, aux_loss)

        if stopping_criteria == "loss":
            return loss

        if aux_loss.get(stopping_criteria) is not None:
            return aux_loss[stopping_criteria]

        raise Exception(f"{stopping_criteria} not found")

    def train_batch(self, batch, *args, **kwargs):
        # This trains the batch
        loss, aux_loss = self.run_iteration(batch, training=True, reduce=True)
        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
                           epoch=self.epoch, log_every=self.log_every, string="train")

    def train_epoch(self, train_loader, *args, **kwargs):
        # This trains the epoch and also calls on_batch_begin and on_batch_end
        # before and after calling train_batch respectively
        self.model.train()
        for i, batch in enumerate(train_loader):
            self.on_batch_begin(i, batch, *args, **kwargs)
            self.train_batch(batch, *args, **kwargs)
            self.on_batch_end(i, batch, *args, **kwargs)
            self.step += 1
        self.model.eval()

    def on_train_begin(self, train_loader, valid_loader, *args, **kwargs):
        # This could be used to add things to the class object, like a scheduler etc.
        if "init" in self.save_strategy and self.epoch == 0:
            self.save(f"{self.result_dir}/init_model.pt")

    def on_epoch_begin(self, train_loader, valid_loader, *args, **kwargs):
        # This is called when an epoch begins
        pass

    def on_batch_begin(self, epoch_step, batch, *args, **kwargs):
        # This is called when a batch begins
        pass

    def on_train_end(self, train_loader, valid_loader, *args, **kwargs):
        # Called when training finishes. For the base trainer we just save the last model
        if "last" in self.save_strategy:
            logger.info("Saving the last model")
            self.save(f"{self.result_dir}/last_model.pt")

    def on_epoch_end(self, train_loader, valid_loader, *args, **kwargs):
        # Called when an epoch ends.
        # We run validation and step the scheduler here, check whether we have a new best
        # model, and save models as required by the save strategy.

        # run validation
        loss, aux_loss = self.validate(train_loader, valid_loader, *args, **kwargs)
        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
                           epoch=self.epoch, log_every=self.log_every, string="val",
                           force_print=True)

        # do scheduler step
        if self.scheduler is not None:
            prev_lr = [group['lr'] for group in self.optimizer.param_groups]
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                criteria = self.compute_criteria(loss, aux_loss)
                self.scheduler.step(criteria)
            else:
                self.scheduler.step()
            new_lr = [group['lr'] for group in self.optimizer.param_groups]

        # If you don't pass a stopping criteria, it won't be computed and the best model won't
        # be saved. Conversely, if you do pass one, the best model is tracked and saved.
        # Pass a large patience to effectively disable early stopping.
        if self.stopping_criteria is not None:
            criteria = self.compute_criteria(loss, aux_loss)

            if (
                (self.best_criteria is None)
                or (self.stopping_criteria_direction == "bigger" and self.best_criteria < criteria)
                or (self.stopping_criteria_direction == "lower" and self.best_criteria > criteria)
            ):
                self.best_criteria = criteria
                self.best_epoch = self.epoch
                self.best_model = copy.deepcopy(
                    {k: v.cpu() for k, v in self.model.state_dict().items()})

                if "best" in self.save_strategy:
                    logger.info(f"Saving best model at epoch {self.epoch}")
                    self.save(f"{self.result_dir}/best_model.pt")

        if "epoch" in self.save_strategy:
            logger.info(f"Saving model at epoch {self.epoch}")
            self.save(f"{self.result_dir}/{self.epoch}_model.pt")

        if "current" in self.save_strategy:
            logger.info(f"Saving model at epoch {self.epoch}")
            self.save(f"{self.result_dir}/current_model.pt")

        # logic to load the best model when the learning rate is reduced
        if self.scheduler is not None and not all(a == b for (a, b) in zip(prev_lr, new_lr)):
            if getattr(self.scheduler, 'load_on_reduce', None) == "best":
                logger.info(f"Loading best model at epoch {self.epoch}")
                # we want to preserve the scheduler
                old_lrs = list(map(lambda x: x['lr'], self.optimizer.param_groups))
                old_scheduler_dict = copy.deepcopy(self.scheduler.state_dict())

                best_model_path = None
                if os.path.exists(f"{self.result_dir}/best_model.pt"):
                    best_model_path = f"{self.result_dir}/best_model.pt"
                else:
                    d = "/".join(self.result_dir.split("/")[:-1])
                    for directory in os.listdir(d):
                        if os.path.exists(f"{d}/{directory}/best_model.pt"):
                            best_model_path = f"{d}/{directory}/best_model.pt"

                if best_model_path is None:
                    raise FileNotFoundError(
                        f"Best model not found in {self.result_dir}; please copy it here if it "
                        f"exists in another folder")

                self.load(best_model_path)
                # override the scheduler to keep the old one and also keep the reduced learning rates
                self.scheduler.load_state_dict(old_scheduler_dict)
                for idx, lr in enumerate(old_lrs):
                    self.optimizer.param_groups[idx]['lr'] = lr
                logger.info(f"Loaded best model and restarting from the end of epoch {self.epoch}")

    def on_batch_end(self, epoch_step, batch, *args, **kwargs):
        # called after a batch is trained
        pass

    def train(self, train_loader, valid_loader, *args, **kwargs):

        self.on_train_begin(train_loader, valid_loader, *args, **kwargs)
        while self.epoch < self.max_epoch:
            # NOTE: incrementing the epoch here (before the epoch runs) is more convenient:
            # we don't need to add 1 just before saving a model, and if we load a model saved
            # at the end of epoch e, we resume cleanly at epoch e + 1.
            self.epoch += 1
            self.on_epoch_begin(train_loader, valid_loader, *args, **kwargs)
            logger.info(f"Starting epoch {self.epoch}")
            self.train_epoch(train_loader, *args, **kwargs)
            self.on_epoch_end(train_loader, valid_loader, *args, **kwargs)

            if self.epoch - self.best_epoch > self.patience:
                logger.info(f"Patience reached; stopping training after {self.epoch} epochs")
                break

        self.on_train_end(train_loader, valid_loader, *args, **kwargs)

    def validate(self, train_loader, valid_loader, *args, **kwargs):
        """
        we expect validate to return the mean loss and other aux losses that we want to log
        """
        losses = []
        aux_losses = {}

        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(valid_loader):
                loss, aux_loss = self.run_iteration(batch, training=False, reduce=False)
                losses.extend(loss.cpu().tolist())

                if i == 0:
                    for k, v in aux_loss.items():
                        # when we can't return sample-wise statistics, we need to do this
                        if len(v.shape) == 0:
                            aux_losses[k] = [v.cpu().tolist()]
                        else:
                            aux_losses[k] = v.cpu().tolist()
                else:
                    for k, v in aux_loss.items():
                        if len(v.shape) == 0:
                            aux_losses[k].append(v.cpu().tolist())
                        else:
                            aux_losses[k].extend(v.cpu().tolist())
        return np.mean(losses), {k: np.mean(v) for (k, v) in aux_losses.items()}

    def test(self, train_loader, test_loader, *args, **kwargs):
        return self.validate(train_loader, test_loader, *args, **kwargs)
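
For context, here is a minimal sketch of how this Trainer might be wired up. It assumes the Trainer above (and its lib.utils.logging.loss_logger_helper dependency) is importable, and a model that follows the interface the trainer expects: forward(batch) returns predictions, loss(pred, batch, reduce=...) returns a loss plus a dict of auxiliary metrics, and the model exposes a device attribute (used by load()). The ToyModel class, the random data, and the "accuracy" key are purely illustrative assumptions, not part of the trainer itself.

# Hypothetical usage sketch; ToyModel, the data, and the "accuracy" key are assumptions.
# Assumes the Trainer class above is defined in (or imported into) this module.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(torch.nn.Module):
    # Minimal model matching the interface the Trainer assumes:
    # forward(batch) -> predictions, loss(pred, batch, reduce) -> (loss, aux_loss dict),
    # plus a `device` attribute used by Trainer.load().
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 2)
        self.device = "cpu"

    def forward(self, batch):
        x, _ = batch
        return self.fc(x.to(self.device))

    def loss(self, pred, batch, reduce=True):
        _, y = batch
        y = y.to(self.device)
        loss = F.cross_entropy(pred, y, reduction="mean" if reduce else "none")
        accuracy = (pred.argmax(dim=-1) == y).float().mean()
        return loss, {"accuracy": accuracy}


model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

data = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
train_loader = DataLoader(data, batch_size=16)
valid_loader = DataLoader(data, batch_size=16)

trainer = Trainer(
    model, optimizer, scheduler,
    result_dir="toy_run_results",        # SummaryWriter will create this directory
    save_strategy=["best", "last"],
    stopping_criteria="accuracy",        # must be "loss" or a key returned in aux_loss
    stopping_criteria_direction="bigger",
    max_epoch=5,
)
trainer.train(train_loader, valid_loader)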