Showing
1 changed file
with
31 additions
and
0 deletions
lib/base_model.py
0 → 100644
1 | +""" base model""" | ||
2 | +import logging | ||
3 | + | ||
4 | +import numpy as np | ||
5 | +import torch.nn as nn | ||
6 | + | ||
7 | +logger = logging.getLogger() | ||
8 | + | ||
9 | + | ||
10 | +class Base(nn.Module): | ||
11 | + """ Base model with some util functions""" | ||
12 | + | ||
13 | + def stats(self, print_model=True): | ||
14 | + # print network model and information about parameters | ||
15 | + logger.info("Model info:::") | ||
16 | + if print_model: | ||
17 | + logger.info(self) | ||
18 | + count = 0 | ||
19 | + for i in self.parameters(): | ||
20 | + count += np.prod(i.shape) | ||
21 | + logger.info(f"Total parameters : {count}") | ||
22 | + | ||
23 | + def to(self, *args, **kwargs): | ||
24 | + if kwargs.get("device"): | ||
25 | + self.device = kwargs.get("device") | ||
26 | + if len(args) > 0: | ||
27 | + self.device = args[0] | ||
28 | + return super().to(*args, **kwargs) | ||
29 | + | ||
30 | + def forward(self, x): | ||
31 | + raise NotImplementedError() |
-
Please register or login to post a comment