Hyunji

standard nn

1 +"""
2 +Utility to create simple sequential networks for classification or regression
3 +
4 +Create feed forward network with different `hidden_sizes`
5 +
6 +Create convolution networks with different `channels` (hidden_size), (2,2) max pooling
7 +"""
import typing

import numpy
import torch.nn as nn

from lib.utils.torch_utils import Reshape, infer_shape

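# NOTE (assumption, inferred from usage below): `Reshape` is taken to flatten
# all non-batch dimensions, and `infer_shape(module, input_shape)` to return
# the module's output shape (without the batch dimension) for an input of
# `input_shape`. Both are project helpers not shown here.
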
# TODO: extend to cover pooling sizes, strides, etc. for conv nets
def get_arch(input_shape: typing.Union[numpy.ndarray, typing.List[int]],
             output_size: int, feed_forward: bool, hidden_sizes: typing.List[int],
             kernel_size: typing.Union[typing.List[int], int] = 3,
             non_linearity: typing.Union[typing.List[str], str, None] = "relu",
             norm: typing.Union[typing.List[str], str, None] = None,
             pooling: typing.Union[typing.List[str], str, None] = None
             ) -> typing.Dict[str, nn.Module]:
    """Build a sequential network and return it wrapped as {"net": module}."""

    # general assertions
    n_layers = len(hidden_sizes)
    if n_layers > 0:
        if isinstance(non_linearity, list):
            assert len(non_linearity) == n_layers, \
                "len(non_linearity) must match len(hidden_sizes)"
            non_linearities = non_linearity
        else:
            non_linearities = [non_linearity] * n_layers

        if isinstance(norm, list):
            assert len(norm) == n_layers, "len(norm) must match len(hidden_sizes)"
            norms = norm
        else:
            norms = [norm] * n_layers
    else:
        norms = []
        non_linearities = []

    modules = []

    if feed_forward:
        # flatten the input, then stack Linear -> non-linearity -> norm blocks
        modules.append(Reshape())
        insize = int(numpy.prod(input_shape))

        for nl, no, outsize in zip(non_linearities, norms, hidden_sizes):
            modules.append(nn.Linear(insize, outsize))

            if nl == "relu":
                modules.append(nn.ReLU())
            elif nl is None:
                pass
            else:
                raise NotImplementedError(f"non-linearity {nl} is not implemented")

            if no == "bn":
                modules.append(nn.BatchNorm1d(outsize))
            elif no is None:
                pass
            else:
                raise NotImplementedError(f"norm {no} is not implemented")

            insize = outsize

        modules.append(nn.Linear(insize, output_size))
        return {"net": nn.Sequential(*modules)}

    # assertions specific to convolutions
    assert n_layers >= 1, "convolutional networks need at least one layer"
    if isinstance(kernel_size, list):
        assert len(kernel_size) == n_layers, "len(kernel_size) must match len(hidden_sizes)"
        kernel_sizes = kernel_size
    else:
        kernel_sizes = [kernel_size] * n_layers

    if isinstance(pooling, list):
        assert len(pooling) == n_layers, "len(pooling) must match len(hidden_sizes)"
        poolings = pooling
    else:
        poolings = [pooling] * n_layers

    # convolutional layers with the configured kernel sizes
    inchannel = input_shape[0]  # input_shape is assumed to be (channels, height, width)
    for nl, no, outchannel, k, p in zip(non_linearities, norms, hidden_sizes, kernel_sizes,
                                        poolings):
        modules.append(nn.Conv2d(inchannel, outchannel, kernel_size=k))

        if nl == "relu":
            modules.append(nn.ReLU())
        elif nl is None:
            pass
        else:
            raise NotImplementedError(f"non-linearity {nl} is not implemented")

        if no == "bn":
            modules.append(nn.BatchNorm2d(outchannel))
        elif no is None:
            pass
        else:
            raise NotImplementedError(f"norm {no} is not implemented")

        if p == "max_pool":
            modules.append(nn.MaxPool2d(2))
        elif p is None:
            pass
        else:
            raise NotImplementedError(f"pooling {p} is not implemented")

        inchannel = outchannel

    # flatten the conv features and map them to the requested output size
    output_shape = infer_shape(nn.Sequential(*modules).to("cpu"), input_shape)
    modules.append(Reshape())
    modules.append(nn.Linear(int(numpy.prod(output_shape)), output_size))
    return {"net": nn.Sequential(*modules)}
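
A minimal usage sketch, not from the original: the CIFAR-10-style (3, 32, 32)
input, the layer sizes, and the names `mlp`/`cnn` are illustrative assumptions,
and it relies on `Reshape`/`infer_shape` behaving as noted above.

import torch

# feed-forward: flattens (3, 32, 32) to 3072, then two hidden layers with BN
mlp = get_arch([3, 32, 32], output_size=10, feed_forward=True,
               hidden_sizes=[256, 128], norm="bn")["net"]

# convolutional: two conv blocks with (2, 2) max pooling, then a linear head
cnn = get_arch([3, 32, 32], output_size=10, feed_forward=False,
               hidden_sizes=[16, 32], kernel_size=3,
               pooling=["max_pool", "max_pool"])["net"]

x = torch.randn(8, 3, 32, 32)  # batch of 8 CHW images
assert mlp(x).shape == (8, 10)
assert cnn(x).shape == (8, 10)

Returning {"net": ...} rather than the bare module presumably keeps the
interface uniform with utilities that return several named modules.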