Showing
1 changed file
with
99 additions
and
0 deletions
2DCNN/lib/utils/torch_utils.py
0 → 100644
1 | +""" things torch should have but it doesn't""" | ||
2 | +import logging | ||
3 | + | ||
4 | +import torch | ||
5 | +import torch.nn as nn | ||
6 | +from torch.autograd import Function | ||
7 | + | ||
8 | +logger = logging.getLogger() | ||
9 | +EPSILON = 1e-8 | ||
10 | + | ||
11 | + | ||
12 | +# reset seed | ||
13 | +def reset_seed(): | ||
14 | + while True: | ||
15 | + try: | ||
16 | + torch.seed() | ||
17 | + except RuntimeError as _: | ||
18 | + logger.error("Error generating seed") | ||
19 | + else: | ||
20 | + break | ||
21 | + | ||
22 | + | ||
23 | +class Reshape(nn.Module): | ||
24 | + """ | ||
25 | + Reshape module that reshapes any input to (batch_size, ...shape) | ||
26 | + by default it does flattening but you can pass any shape. | ||
27 | + """ | ||
28 | + | ||
29 | + def __init__(self, shape=(-1,)): | ||
30 | + super().__init__() | ||
31 | + self.shape = shape | ||
32 | + | ||
33 | + def forward(self, x): | ||
34 | + batch_size = x.shape[0] | ||
35 | + return x.view((batch_size,) + self.shape) | ||
36 | + | ||
37 | + def extra_repr(self): | ||
38 | + return f"shape={self.shape}" | ||
39 | + | ||
40 | + | ||
41 | +class Offset(torch.nn.Module): | ||
42 | + def __init__(self, offset, net): | ||
43 | + super().__init__() | ||
44 | + self.offset = nn.Parameter(offset, requires_grad=False) | ||
45 | + self.net = net | ||
46 | + | ||
47 | + def forward(self, *args): | ||
48 | + batch_size = args[0].shape[0] | ||
49 | + return self.offset.expand((batch_size, -1, -1, -1)) + 1e-8 # + self.net(*args) | ||
50 | + | ||
51 | + | ||
52 | +def batch_eye(N, D, device="cpu"): | ||
53 | + x = torch.eye(D, device=device) | ||
54 | + x = x.unsqueeze(0) | ||
55 | + x = x.repeat(N, 1, 1) | ||
56 | + return x | ||
57 | + | ||
58 | + | ||
59 | +def batch_eye_like(tensor): | ||
60 | + assert len(tensor.shape) == 3 and tensor.shape[1] == tensor.shape[2] | ||
61 | + N = tensor.shape[0] | ||
62 | + D = tensor.shape[1] | ||
63 | + return batch_eye(N, D, device=tensor.device) | ||
64 | + | ||
65 | + | ||
66 | +class _RevGrad(Function): | ||
67 | + @staticmethod | ||
68 | + def forward(ctx, input_): | ||
69 | + ctx.save_for_backward(input_) | ||
70 | + output = input_ | ||
71 | + return output | ||
72 | + | ||
73 | + @staticmethod | ||
74 | + def backward(ctx, grad_output): | ||
75 | + grad_input = None | ||
76 | + if ctx.needs_input_grad[0]: | ||
77 | + grad_input = -grad_output | ||
78 | + return grad_input | ||
79 | + | ||
80 | + | ||
81 | +revgrad = _RevGrad.apply | ||
82 | + | ||
83 | + | ||
84 | +class RevGrad(nn.Module): | ||
85 | + def __init__(self, *args, **kwargs): | ||
86 | + """ | ||
87 | + A gradient reversal layer. | ||
88 | + This layer has no parameters, and simply reverses the gradient | ||
89 | + in the backward pass. | ||
90 | + """ | ||
91 | + super().__init__(*args, **kwargs) | ||
92 | + | ||
93 | + def forward(self, input_): | ||
94 | + return revgrad(input_) | ||
95 | + | ||
96 | + | ||
97 | +def infer_shape(net, input_shape): | ||
98 | + x = torch.rand((2,) + input_shape) | ||
99 | + return net(x).shape[1:] |
-
Please register or login to post a comment