Showing
1 changed file
with
129 additions
and
0 deletions
src/arch/brain_age_slice_lstm.py
0 → 100644
1 | +import torch | ||
2 | +from box import Box | ||
3 | +from torch import nn | ||
4 | + | ||
5 | + | ||
def encoder_blk(in_channels, out_channels):
    """Return one encoder stage: 3x3 conv -> instance norm -> 2x2 max-pool -> ReLU.

    The convolution uses padding 1 and stride 1, so it keeps the spatial size;
    the max-pool (kernel 2, stride 2) then halves each spatial dimension,
    rounding down.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in the order listed above.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.InstanceNorm2d(out_channels),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
13 | + | ||
14 | + | ||
class MRI_LSTM(nn.Module):
    """Slice-wise 2D CNN encoder followed by an LSTM and a scalar regressor.

    A 5-D volume ``(B, C, H, W, D)`` is cut into 2D slices along one spatial
    axis (selected by ``slice_dim``).  Every slice goes through the same
    convolutional encoder, the per-slice feature vectors are fed to an LSTM as
    a sequence, and the LSTM's final hidden state is mapped to one value per
    volume.

    Args:
        lstm_feat_dim: size of the per-slice feature vector fed to the LSTM.
        lstm_latent_dim: hidden size of the LSTM.
        slice_dim: which spatial axis to slice along; 1, 2 or 3 (maps to tensor
            axes 2, 3 and 4 respectively).
        *args, **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``slice_dim`` is not 1, 2 or 3.
    """

    # Per-slice_dim geometry, indexed by ``slice_dim - 1``.
    # NOTE(review): presumably describes a 91 x 109 x 91 input volume with the
    # sliced axis set to 1 - confirm against the data loader.
    _INPUT_DIMS = ((1, 109, 91), (91, 1, 91), (91, 109, 1))

    # Average-pool kernel applied to the encoder output for each slice_dim.
    # NOTE(review): these appear chosen so a 91 x 109 x 91 volume ends up with
    # a 1x1 feature map after five 2x poolings - confirm for other input sizes.
    _AVG_KERNELS = {1: [3, 2], 2: [2, 2], 3: [2, 3]}

    def __init__(self, lstm_feat_dim, lstm_latent_dim, slice_dim, *args, **kwargs):
        super().__init__()

        # Validate before indexing: previously slice_dim=4 surfaced as an
        # IndexError and slice_dim=0 silently picked the last entry.
        if slice_dim not in self._AVG_KERNELS:
            raise ValueError("Invalid slice dim")

        self.input_dim = self._INPUT_DIMS[slice_dim - 1]

        self.feat_embed_dim = lstm_feat_dim
        self.latent_dim = lstm_latent_dim

        # Build Encoder: five conv stages, each halving the spatial size.
        encoder_blocks = [
            encoder_blk(1, 32),
            encoder_blk(32, 64),
            encoder_blk(64, 128),
            encoder_blk(128, 256),
            encoder_blk(256, 256),
        ]
        self.encoder = nn.Sequential(*encoder_blocks)

        avg = nn.AvgPool2d(self._AVG_KERNELS[slice_dim])
        self.slice_dim = slice_dim

        # Post processing: 1x1 convs around the average pool reduce each slice
        # to a ``feat_embed_dim``-channel feature map.
        self.post_proc = nn.Sequential(
            nn.Conv2d(256, 64, 1, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            avg,
            nn.Dropout(p=0.5),
            nn.Conv2d(64, self.feat_embed_dim, 1),
        )

        # Connect w/ LSTM (single layer, batch-first input).
        self.n_layers = 1
        self.lstm = nn.LSTM(
            self.feat_embed_dim, self.latent_dim, self.n_layers, batch_first=True
        )

        # Build regressor
        self.lstm_post = nn.Linear(self.latent_dim, 64)
        self.regressor = nn.Sequential(nn.ReLU(), nn.Linear(64, 1))

        self.init_weights()

    def init_weights(self):
        """Re-initialise conv and linear layers in place.

        * ``Conv2d``: Kaiming-normal weights (fan-out, ReLU), zero bias.
        * ``Linear`` inside ``regressor``: only the bias is set, to 62.68; the
          weight keeps PyTorch's default initialisation.
          NOTE(review): 62.68 is presumably the mean of the regression target
          - confirm.
        * Any other ``Linear``: weights ~ N(0, 0.01), zero bias.
        """
        for k, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear) and "regressor" in k:
                m.bias.data.fill_(62.68)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def init_hidden(self, x):
        """Return zero ``(h_0, c_0)`` LSTM states sized for the batch in ``x``.

        Both tensors have shape ``(n_layers, x.size(0), latent_dim)``, live on
        ``x.device`` and have ``requires_grad`` enabled.  ``encode`` does not
        call this; it relies on ``nn.LSTM`` defaulting to zero initial states.
        """
        h_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        c_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        h_0.requires_grad = True
        c_0.requires_grad = True
        return h_0, c_0

    def encode(self, x):
        """Encode a batch of volumes into LSTM hidden states.

        Args:
            x: 5-D tensor ``(B, C, H, W, D)``.  ``C`` must match the encoder's
                input channels (1 for the encoder built in ``__init__``).

        Returns:
            Tensor of shape ``(B, latent_dim)``: the last layer's final hidden
            state of the LSTM.

        Raises:
            ValueError: if ``x`` is not 5-D or ``self.slice_dim`` is invalid.
            RuntimeError: if the per-slice feature map is not 1x1 spatially
                (the input size does not match the configured pooling).
        """
        if self.slice_dim not in self._AVG_KERNELS:
            raise ValueError("Invalid slice dim")
        if x.dim() != 5:
            raise ValueError(
                "expected a 5-D input (B, C, H, W, D), got %d dims" % x.dim()
            )

        batch_size = x.size(0)
        # slice_dim 1/2/3 selects tensor axis 2/3/4 (H/W/D).
        slice_axis = self.slice_dim + 1

        # Stack all slices along the batch axis, slice-major: rows
        # [i * B, (i + 1) * B) hold slice i of every volume in the batch.
        stacked = torch.cat(torch.unbind(x, dim=slice_axis), dim=0)

        feats = self.post_proc(self.encoder(stacked))
        if feats.dim() != 4 or tuple(feats.shape[2:]) != (1, 1):
            # Fail loudly here instead of with an opaque shape error later.
            raise RuntimeError(
                "per-slice features must be 1x1 spatially, got shape %s"
                % (tuple(feats.shape),)
            )

        # (n_slices * B, F, 1, 1) -> n_slices chunks of (B, F) -> (B, n_slices, F).
        # Splitting by batch size keeps the batch axis intact even when B == 1.
        per_slice = torch.split(feats.flatten(1), batch_size, dim=0)
        sequence = torch.stack(per_slice, dim=1)

        # lstm takes batch x seq_len x dim; initial states default to zeros.
        _, (h_n, _) = self.lstm(sequence)
        # h_n is n_layers x batch x hidden; keep the last layer.
        return h_n[-1]

    def forward(self, x):
        """Predict one scalar per volume.

        Returns:
            ``Box`` with key ``"y_pred"`` holding a ``(B, 1)`` tensor.
        """
        embedding = self.encode(x)
        post = self.lstm_post(embedding)
        y_pred = self.regressor(post)
        return Box({"y_pred": y_pred})
126 | + | ||
127 | + | ||
def get_arch(*args, **kwargs):
    """Build an ``MRI_LSTM`` from the given arguments.

    All positional and keyword arguments are forwarded unchanged to the
    ``MRI_LSTM`` constructor.

    Returns:
        A dict with the single key ``"net"`` mapped to the new model.
    """
    model = MRI_LSTM(*args, **kwargs)
    return {"net": model}
-
Please register or login to post a comment