Showing 1 changed file with 117 additions and 0 deletions
2DCNN/lib/standard_nn.py (new file, 0 → 100644)
1 | +""" | ||
2 | +Utility to create simple sequential networks for classification or regression | ||
3 | + | ||
4 | +Create feed forward network with different `hidden_sizes` | ||
5 | + | ||
6 | +Create convolution networks with different `channels` (hidden_size), (2,2) max pooling | ||
7 | +""" | ||
import typing

import numpy
import torch.nn as nn

from lib.utils.torch_utils import Reshape, infer_shape


# TODO: extend to cover pooling sizes, strides, etc. for conv nets
def get_arch(input_shape: typing.Union[numpy.ndarray, typing.List], output_size: int,
             feed_forward: bool, hidden_sizes: typing.List[int],
             kernel_size: typing.Union[typing.List[int], int] = 3,
             non_linearity: typing.Union[typing.List[str], str, None] = "relu",
             norm: typing.Union[typing.List[str], str, None] = None,
             pooling: typing.Union[typing.List[str], str, None] = None) -> typing.Dict[str, nn.Module]:

    # general assertions: broadcast scalar arguments so there is one entry per layer
    n_layers = len(hidden_sizes)
    if n_layers > 0:
        if isinstance(non_linearity, list):
            assert len(non_linearity) == n_layers, "non_linearity list length must match hidden_sizes"
            non_linearities = non_linearity
        else:
            non_linearities = [non_linearity] * n_layers

        if isinstance(norm, list):
            assert len(norm) == n_layers, "norm list length must match hidden_sizes"
            norms = norm
        else:
            norms = [norm] * n_layers
    else:
        norms = []
        non_linearities = []

    modules = []

    if feed_forward:
        # flatten the input, then stack Linear -> activation -> norm blocks
        modules.append(Reshape())
        insize = int(numpy.prod(input_shape))

        for nl, no, outsize in zip(non_linearities, norms, hidden_sizes):
            modules.append(nn.Linear(insize, outsize))

            if nl == "relu":
                modules.append(nn.ReLU())
            elif nl is None:
                pass
            else:
                raise Exception(f"non-linearity {nl} is not implemented")

            if no == "bn":
                modules.append(nn.BatchNorm1d(outsize))
            elif no is None:
                pass
            else:
                raise Exception(f"norm {no} is not implemented")

            insize = outsize

        modules.append(nn.Linear(insize, output_size))
        return {"net": nn.Sequential(*modules)}

    # assertions specific to convolutions
    assert n_layers >= 1, "Convolutional networks need at least one layer"
    if isinstance(kernel_size, list):
        assert len(kernel_size) == n_layers, "kernel_size list length must match hidden_sizes"
        kernel_sizes = kernel_size
    else:
        kernel_sizes = [kernel_size] * n_layers

    if isinstance(pooling, list):
        assert len(pooling) == n_layers, "pooling list length must match hidden_sizes"
        poolings = pooling
    else:
        poolings = [pooling] * n_layers

    # convolutional layers; hidden_sizes are interpreted as channel counts
    inchannel = input_shape[0]
    for nl, no, outchannel, k, p in zip(non_linearities, norms, hidden_sizes, kernel_sizes,
                                        poolings):
        modules.append(nn.Conv2d(inchannel, outchannel, kernel_size=k))

        if nl == "relu":
            modules.append(nn.ReLU())
        elif nl is None:
            pass
        else:
            raise Exception(f"non-linearity {nl} is not implemented")

        if no == "bn":
            modules.append(nn.BatchNorm2d(outchannel))
        elif no is None:
            pass
        else:
            raise Exception(f"norm {no} is not implemented")

        if p == "max_pool":
            modules.append(nn.MaxPool2d(2))
        elif p is None:
            pass
        else:
            raise Exception(f"pooling {p} is not implemented")

        inchannel = outchannel

    # infer the flattened output shape of the conv stack so the linear head
    # gets the right input size
    output_shape = infer_shape(nn.Sequential(*modules).to("cpu"), input_shape)
    modules.append(Reshape())
    modules.append(nn.Linear(int(numpy.prod(output_shape)), output_size))
    return {"net": nn.Sequential(*modules)}
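
For context, a minimal usage sketch (not part of the commit): it assumes the repo's `Reshape` flattens its input and that `infer_shape` returns the module's output shape for a given input shape; the 1x28x28 input below is an illustrative assumption, not something fixed by the code.

import torch

from lib.standard_nn import get_arch

# illustrative 1x28x28 input (e.g. MNIST-sized images)
ff = get_arch(input_shape=[1, 28, 28], output_size=10, feed_forward=True,
              hidden_sizes=[256, 128], non_linearity="relu", norm="bn")["net"]

cnn = get_arch(input_shape=[1, 28, 28], output_size=10, feed_forward=False,
               hidden_sizes=[32, 64], kernel_size=3,
               non_linearity="relu", norm="bn", pooling="max_pool")["net"]

x = torch.randn(4, 1, 28, 28)     # batch of 4 dummy images
print(ff(x).shape, cnn(x).shape)  # expected: torch.Size([4, 10]) for both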