Showing 1 changed file with 342 additions and 0 deletions
lib/base_trainer.py
0 → 100644
1 | +"""trainer code""" | ||
2 | +import copy | ||
3 | +import logging | ||
4 | +import os | ||
5 | +from typing import List, Dict, Optional, Callable, Union | ||
6 | + | ||
7 | +import dill | ||
8 | +import numpy as np | ||
9 | +import torch | ||
10 | +from torch.utils.tensorboard import SummaryWriter | ||
11 | + | ||
12 | +from lib.utils.logging import loss_logger_helper | ||
13 | + | ||
14 | +logger = logging.getLogger() | ||
15 | + | ||
16 | + | ||
+class Trainer:
+    # similar to skorch, but instead of callbacks we override class methods (less magic)
+    # this is an evolving template
+    def __init__(
+            self,
+            model: torch.nn.Module,
+            optimizer: torch.optim.Optimizer,
+            scheduler: torch.optim.lr_scheduler._LRScheduler,
+            result_dir: Optional[str],
+            statefile: Optional[str] = None,
+            log_every: int = 100,
+            save_strategy: Optional[List] = None,
+            patience: int = 20,
+            max_epoch: int = 100,
+            gradient_norm_clip: float = -1,
+            stopping_criteria_direction: str = "bigger",
+            stopping_criteria: Optional[Union[str, Callable]] = "accuracy",
+            evaluations=None,
+            **kwargs,
+    ):
+        """
+        stopping_criteria: a callable, a string, or None. If a string, it must be "loss" or
+            match one of the keys in aux_loss; if None, early stopping is not invoked.
+        save_strategy: any subset of ["init", "best", "epoch", "current", "last"],
+            controlling which checkpoints are written.
+        """
+        super().__init__()
+
+        self.result_dir = result_dir
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.evaluations = evaluations
+        self.gradient_norm_clip = gradient_norm_clip
+
+        # training state related params
+        self.epoch = 0
+        self.step = 0
+        self.best_criteria = None
+        self.best_epoch = -1
+
+        # config related param
+        self.log_every = log_every
+        # default to an empty list so that membership checks ("best" in self.save_strategy) are safe
+        self.save_strategy = save_strategy if save_strategy is not None else []
+        self.patience = patience
+        self.max_epoch = max_epoch
+        self.stopping_criteria_direction = stopping_criteria_direction
+        self.stopping_criteria = stopping_criteria
+
+        # TODO: should save config and see if things have changed?
+        if statefile is not None:
+            self.load(statefile)
+
+        # init best model
+        self.best_model = self.model.state_dict()
+
+        # logging stuff
+        if result_dir is not None:
+            # we do not need to purge. Purging can delete the validation result
+            self.summary_writer = SummaryWriter(log_dir=result_dir)
+
+    def load(self, fname: str) -> Dict:
+        """
+        fname: file name to load data from
+        """
+
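+        # NOTE: the model is assumed to expose a .device attribute; it is used as map_location here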
+        data = torch.load(open(fname, "rb"), pickle_module=dill, map_location=self.model.device)
+
+        if getattr(self, "model", None) and data.get("model") is not None:
+            state_dict = self.model.state_dict()
+            state_dict.update(data["model"])
+            self.model.load_state_dict(state_dict)
+
+        if getattr(self, "optimizer", None) and data.get("optimizer") is not None:
+            optimizer_dict = self.optimizer.state_dict()
+            optimizer_dict.update(data["optimizer"])
+            self.optimizer.load_state_dict(optimizer_dict)
+
+        if getattr(self, "scheduler", None) and data.get("scheduler") is not None:
+            scheduler_dict = self.scheduler.state_dict()
+            scheduler_dict.update(data["scheduler"])
+            self.scheduler.load_state_dict(scheduler_dict)
+
+        self.epoch = data["epoch"]
+        self.step = data["step"]
+        self.best_criteria = data["best_criteria"]
+        self.best_epoch = data["best_epoch"]
+        return data
+
+    def save(self, fname: str, **kwargs):
+        """
+        fname: file name to save to
+        kwargs: additional objects to include in the checkpoint.
+
+        By default we save:
+            - model,
+            - optimizer,
+            - epoch,
+            - step,
+            - best_criteria,
+            - best_epoch
+        """
+        # NOTE: the best model is kept in memory but only written to disk according to the
+        # save strategy, so that it can be loaded outside of the training process
+        kwargs.update({
+            "model": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "epoch": self.epoch,
+            "step": self.step,
+            "best_criteria": self.best_criteria,
+            "best_epoch": self.best_epoch,
+        })
+
+        if self.scheduler is not None:
+            kwargs.update({"scheduler": self.scheduler.state_dict()})
+
+        torch.save(kwargs, open(fname, "wb"), pickle_module=dill)
+
+    # TODO: allow extracting predictions
+    def run_iteration(self, batch, training: bool = True, reduce: bool = True):
+        """
+        batch: batch of data, passed to the model as is
+        training: if True, run the backward pass and an optimizer step
+        reduce: if True, return the mean loss; otherwise return the per-sample loss vector
+        """
+        pred = self.model(batch)
+        loss, aux_loss = self.model.loss(pred, batch, reduce=reduce)
+
+        if training:
+            loss.backward()
+            if self.gradient_norm_clip > 0:
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_norm_clip)
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+        return loss, aux_loss
+
+    def compute_criteria(self, loss, aux_loss):
+        stopping_criteria = self.stopping_criteria
+        if stopping_criteria is None:
+            return loss
+
+        if callable(stopping_criteria):
+            return stopping_criteria(loss, aux_loss)
+
+        if stopping_criteria == "loss":
+            return loss
+
+        if aux_loss.get(stopping_criteria) is not None:
+            return aux_loss[stopping_criteria]
+
+        raise Exception(f"{stopping_criteria} not found")
+
+    def train_batch(self, batch, *args, **kwargs):
+        # This trains the batch
+        loss, aux_loss = self.run_iteration(batch, training=True, reduce=True)
+        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
+                           epoch=self.epoch, log_every=self.log_every, string="train")
+
+    def train_epoch(self, train_loader, *args, **kwargs):
+        # This trains the epoch and calls on_batch_begin and on_batch_end
+        # before and after calling train_batch respectively
+        self.model.train()
+        for i, batch in enumerate(train_loader):
+            self.on_batch_begin(i, batch, *args, **kwargs)
+            self.train_batch(batch, *args, **kwargs)
+            self.on_batch_end(i, batch, *args, **kwargs)
+            self.step += 1
+        self.model.eval()
+
+    def on_train_begin(self, train_loader, valid_loader, *args, **kwargs):
+        # this could be used to add things to the class object, like a scheduler etc.
+        if "init" in self.save_strategy:
+            if self.epoch == 0:
+                self.save(f"{self.result_dir}/init_model.pt")
+
+    def on_epoch_begin(self, train_loader, valid_loader, *args, **kwargs):
+        # This is called when an epoch begins
+        pass
+
+    def on_batch_begin(self, epoch_step, batch, *args, **kwargs):
+        # This is called when a batch begins
+        pass
+
+    def on_train_end(self, train_loader, valid_loader, *args, **kwargs):
+        # Called when training finishes. For the base trainer we just save the last model
+        if "last" in self.save_strategy:
+            logger.info("Saving the last model")
+            self.save(f"{self.result_dir}/last_model.pt")
+
+    def on_epoch_end(self, train_loader, valid_loader, *args, **kwargs):
+        # called when epoch ends
+        # we call validation, scheduler here
+        # also check if we have a new best model and save model if needed
+
+        # call validate
+        loss, aux_loss = self.validate(train_loader, valid_loader, *args, **kwargs)
+        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
+                           epoch=self.epoch, log_every=self.log_every, string="val",
+                           force_print=True)
+
+        # do scheduler step
+        if self.scheduler is not None:
+            prev_lr = [group['lr'] for group in self.optimizer.param_groups]
+            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                criteria = self.compute_criteria(loss, aux_loss)
+                self.scheduler.step(criteria)
+            else:
+                self.scheduler.step()
+            new_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+        # if no stopping criteria is passed, it is not computed and the best model is not saved;
+        # if one is passed, the best model is saved.
+        # Pass a large patience to effectively disable early stopping.
+        if self.stopping_criteria is not None:
+            criteria = self.compute_criteria(loss, aux_loss)
+
+            if (
+                    (self.best_criteria is None)
+                    or (self.stopping_criteria_direction == "bigger" and self.best_criteria < criteria)
+                    or (self.stopping_criteria_direction == "lower" and self.best_criteria > criteria)
+            ):
+                self.best_criteria = criteria
+                self.best_epoch = self.epoch
+                self.best_model = copy.deepcopy(
+                    {k: v.cpu() for k, v in self.model.state_dict().items()})
+
+                if "best" in self.save_strategy:
+                    logger.info(f"Saving best model at epoch {self.epoch}")
+                    self.save(f"{self.result_dir}/best_model.pt")
+
+        if "epoch" in self.save_strategy:
+            logger.info(f"Saving model at epoch {self.epoch}")
+            self.save(f"{self.result_dir}/{self.epoch}_model.pt")
+
+        if "current" in self.save_strategy:
+            logger.info(f"Saving model at epoch {self.epoch}")
+            self.save(f"{self.result_dir}/current_model.pt")
+
+        # if the learning rate was reduced, optionally reload the best model
+        if self.scheduler is not None and not all(a == b for (a, b) in zip(prev_lr, new_lr)):
+            if getattr(self.scheduler, 'load_on_reduce', None) == "best":
+                logger.info(f"Loading best model at epoch {self.epoch}")
+                # we want to preserve the scheduler state and the reduced learning rates
+                old_lrs = list(map(lambda x: x['lr'], self.optimizer.param_groups))
+                old_scheduler_dict = copy.deepcopy(self.scheduler.state_dict())
+
+                best_model_path = None
+                if os.path.exists(f"{self.result_dir}/best_model.pt"):
+                    best_model_path = f"{self.result_dir}/best_model.pt"
+                else:
+                    # look for a best_model.pt in sibling result directories
+                    d = "/".join(self.result_dir.split("/")[:-1])
+                    for directory in os.listdir(d):
+                        if os.path.exists(f"{d}/{directory}/best_model.pt"):
+                            best_model_path = f"{d}/{directory}/best_model.pt"
+
+                if best_model_path is None:
+                    raise FileNotFoundError(
+                        f"Best model not found in {self.result_dir}; please copy it there if it "
+                        f"exists in another folder")
+
+                self.load(best_model_path)
+                # restore the scheduler and keep the reduced learning rates
+                self.scheduler.load_state_dict(old_scheduler_dict)
+                for idx, lr in enumerate(old_lrs):
+                    self.optimizer.param_groups[idx]['lr'] = lr
+                logger.info(f"Loaded best model; restarting from the end of epoch {self.epoch}")
+
+    def on_batch_end(self, epoch_step, batch, *args, **kwargs):
+        # called after a batch is trained
+        pass
+
+    def train(self, train_loader, valid_loader, *args, **kwargs):
+
+        self.on_train_begin(train_loader, valid_loader, *args, **kwargs)
+        while self.epoch < self.max_epoch:
+            # NOTE: incrementing the epoch here (before training) is more convenient: a model
+            # saved at the end of epoch e resumes training at e+1, so we never have to redo
+            # the last epoch after loading
+            self.epoch += 1
+            self.on_epoch_begin(train_loader, valid_loader, *args, **kwargs)
+            logger.info(f"Starting epoch {self.epoch}")
+            self.train_epoch(train_loader, *args, **kwargs)
+            self.on_epoch_end(train_loader, valid_loader, *args, **kwargs)
+
+            if self.epoch - self.best_epoch > self.patience:
+                logger.info(f"Patience reached; stopping training after {self.epoch} epochs")
+                break
+
+        self.on_train_end(train_loader, valid_loader, *args, **kwargs)
+
+    def validate(self, train_loader, valid_loader, *args, **kwargs):
+        """
+        validate is expected to return the mean loss and any aux losses that we want to log
+        """
+        losses = []
+        aux_losses = {}
+
+        self.model.eval()
+        with torch.no_grad():
+            for i, batch in enumerate(valid_loader):
+                loss, aux_loss = self.run_iteration(batch, training=False, reduce=False)
+                losses.extend(loss.cpu().tolist())
+
+                if i == 0:
+                    for k, v in aux_loss.items():
+                        # scalar aux losses have no sample-wise statistics,
+                        # so collect them as one value per batch
+                        if len(v.shape) == 0:
+                            aux_losses[k] = [v.cpu().tolist()]
+                        else:
+                            aux_losses[k] = v.cpu().tolist()
+                else:
+                    for k, v in aux_loss.items():
+                        if len(v.shape) == 0:
+                            aux_losses[k].append(v.cpu().tolist())
+                        else:
+                            aux_losses[k].extend(v.cpu().tolist())
+        return np.mean(losses), {k: np.mean(v) for (k, v) in aux_losses.items()}
+
+    def test(self, train_loader, test_loader, *args, **kwargs):
+        return self.validate(train_loader, test_loader, *args, **kwargs)
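
For reference, a minimal usage sketch (not part of the diff). MyModel, the random tensors, and the hyperparameters below are illustrative placeholders; the only contract taken from the code above is that the model exposes a .device attribute and a loss(pred, batch, reduce=...) method returning (loss, aux_loss_dict), and that stopping_criteria is "loss" or a key of that dict.

import torch
from lib.base_trainer import Trainer

class MyModel(torch.nn.Module):  # placeholder model satisfying the trainer's contract
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)
        self.device = "cpu"  # Trainer.load uses this as map_location

    def forward(self, batch):
        x, _ = batch
        return self.linear(x)

    def loss(self, pred, batch, reduce=True):
        _, y = batch
        # per-sample losses when reduce=False, as Trainer.validate expects
        ce = torch.nn.functional.cross_entropy(pred, y, reduction="mean" if reduce else "none")
        acc = (pred.argmax(dim=-1) == y).float().mean()
        return ce, {"accuracy": acc}

# toy data; any DataLoader yielding (x, y) batches works here
x = torch.randn(64, 10)
y = torch.randint(0, 2, (64,))
dataset = torch.utils.data.TensorDataset(x, y)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset, batch_size=16)

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    result_dir="results/demo",              # hypothetical output folder
    save_strategy=["best", "last"],
    stopping_criteria="accuracy",           # must be "loss" or a key of aux_loss
    stopping_criteria_direction="bigger",
    max_epoch=50,
    patience=10,
)
trainer.train(train_loader, valid_loader)

With save_strategy=["best", "last"], best_model.pt is written from on_epoch_end whenever the criteria improves and last_model.pt from on_train_end, both under result_dir.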