verify_mri_lstm.py
"""
Code to test our LSTM implementation against Lam et al.
Our implementation uses vectorization and should be faster, but it still needs
to be verified against the original slice-by-slice loop; a rough timing sketch
is appended at the end of the file.
"""
import torch
from torch import nn

def encoder_blk(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=1),
        nn.InstanceNorm2d(out_channels),
        nn.MaxPool2d(2, stride=2),
        nn.ReLU()
    )

class MRI_LSTM(nn.Module):

    def __init__(self, lstm_feat_dim, lstm_latent_dim, *args, **kwargs):
        super(MRI_LSTM, self).__init__()
        self.input_dim = (1, 109, 91)
        self.feat_embed_dim = lstm_feat_dim
        self.latent_dim = lstm_latent_dim
        # Build encoder
        encoder_blocks = [
            encoder_blk(1, 32),
            encoder_blk(32, 64),
            encoder_blk(64, 128),
            encoder_blk(128, 256),
            encoder_blk(256, 256)
        ]
        self.encoder = nn.Sequential(*encoder_blocks)
        # Post-processing
        self.post_proc = nn.Sequential(
            nn.Conv2d(256, 64, 1, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.AvgPool2d([3, 2]),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, self.feat_embed_dim, 1)
        )
        # Connect w/ LSTM
        self.n_layers = 1
        self.lstm = nn.LSTM(self.feat_embed_dim, self.latent_dim, self.n_layers, batch_first=True)
        # Build regressor
        self.lstm_post = nn.Linear(self.latent_dim, 64)
        self.regressor = nn.Sequential(nn.ReLU(), nn.Linear(64, 1))
        self.init_weights()

    def init_weights(self):
        for k, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear) and "regressor" in k:
                # The final regressor bias starts at a fixed constant
                # (presumably the training-set mean of the regression target).
                m.bias.data.fill_(62.68)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def init_hidden(self, x):
        h_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        c_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        h_0.requires_grad = True
        c_0.requires_grad = True
        return h_0, c_0

    def encode_old(self, x):
        # Reference implementation: encode the volume slice by slice and feed
        # each embedding to the LSTM one step at a time.
        B, C, H, W, D = x.size()
        h_t, c_t = self.init_hidden(x)
        for i in range(H):
            out = self.encoder(x[:, :, i, :, :])
            out = self.post_proc(out)
            out = out.view(B, 1, self.feat_embed_dim)
            h_t = h_t.view(1, B, self.latent_dim)
            c_t = c_t.view(1, B, self.latent_dim)
            h_t, (_, c_t) = self.lstm(out, (h_t, c_t))
        encoding = h_t.view(B, self.latent_dim)
        return encoding

    def encode_new(self, x):
        # Vectorized implementation: encode all slices in one batched pass,
        # then run the LSTM over the whole sequence at once.
        h_0, c_0 = self.init_hidden(x)  # unused: nn.LSTM defaults to a zero initial state
        B, C, H, W, D = x.size()
        # Convert to 2D images, apply the encoder, then reshape for the LSTM
        new_input = torch.cat([x[:, :, i, :, :] for i in range(H)], dim=0)
        encoding = self.encoder(new_input)
        encoding = self.post_proc(encoding)
        # (B*H) x C_out x W_out x D_out
        encoding = torch.stack(torch.split(encoding, B, dim=0), dim=2)
        # B x C_out x H x W_out x D_out
        encoding = encoding.squeeze(4).squeeze(3)
        # The LSTM takes batch x seq_len x dim
        encoding = encoding.permute(0, 2, 1)
        _, (encoding, _) = self.lstm(encoding)
        # Final hidden state is 1 x batch x hidden
        encoding = encoding.squeeze(0)
        return encoding

    def forward(self, x):
        embedding_old = self.encode_old(x)
        embedding_new = self.encode_new(x)
        return embedding_new, embedding_old

if __name__ == "__main__":
    B = 4
    new_model = MRI_LSTM(lstm_feat_dim=2, lstm_latent_dim=128)
    new_model.eval()  # disable dropout so both paths are deterministic
    inp = torch.rand(B, 1, 91, 109, 91)
    output = new_model(inp)
    print("encodings match:", torch.allclose(output[0], output[1]))
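
    # Rough timing sketch (an addition, not part of the original check): the
    # module docstring claims the vectorized path should be faster, so compare
    # both encoders on the same input. Wall-clock numbers depend on hardware
    # and input size, so treat this as a sanity check rather than a benchmark.
    import time
    with torch.no_grad():
        for name, encode in (("loop", new_model.encode_old), ("vectorized", new_model.encode_new)):
            start = time.perf_counter()
            encode(inp)
            print(f"{name} encoder: {time.perf_counter() - start:.3f} s")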