Hyunji

base model

1 +""" base model"""
2 +import logging
3 +
4 +import numpy as np
5 +import torch.nn as nn
6 +
7 +logger = logging.getLogger()
8 +
9 +
10 +class Base(nn.Module):
11 + """ Base model with some util functions"""
12 +
13 + def stats(self, print_model=True):
14 + # print network model and information about parameters
15 + logger.info("Model info:::")
16 + if print_model:
17 + logger.info(self)
18 + count = 0
19 + for i in self.parameters():
20 + count += np.prod(i.shape)
21 + logger.info(f"Total parameters : {count}")
22 +
23 + def to(self, *args, **kwargs):
24 + if kwargs.get("device"):
25 + self.device = kwargs.get("device")
26 + if len(args) > 0:
27 + self.device = args[0]
28 + return super().to(*args, **kwargs)
29 +
30 + def forward(self, x):
31 + raise NotImplementedError()