Hyunji

brain age slice set

1 +"""code for attention models"""
2 +
3 +import math
4 +
5 +import torch
6 +from box import Box
7 +from torch import nn
8 +
9 +
class MeanPool(nn.Module):
    """Mean-pool the slice encodings; returns None in place of attention weights."""

    def forward(self, X):
        return X.mean(dim=1, keepdim=True), None


class MaxPool(nn.Module):
    """Max-pool the slice encodings; returns None in place of attention weights."""

    def forward(self, X):
        return X.max(dim=1, keepdim=True)[0], None


class PooledAttention(nn.Module):
    """Pool a set of slice encodings with multi-head attention against a learned query."""

    def __init__(self, input_dim, dim_v, dim_k, num_heads, ln=False):
        super(PooledAttention, self).__init__()
        # learned seed vector used as the query
        self.S = nn.Parameter(torch.zeros(1, dim_k))
        nn.init.xavier_uniform_(self.S)

        # transforms to get the key and value vectors
        self.fc_k = nn.Linear(input_dim, dim_k)
        self.fc_v = nn.Linear(input_dim, dim_v)

        self.dim_v = dim_v
        self.dim_k = dim_k
        self.num_heads = num_heads

        if ln:
            self.ln0 = nn.LayerNorm(dim_v)

    def forward(self, X):
        B, C, H = X.shape

        # broadcast the learned query over the batch
        Q = self.S.repeat(X.size(0), 1, 1)

        K = self.fc_k(X.reshape(-1, H)).reshape(B, C, self.dim_k)
        V = self.fc_v(X.reshape(-1, H)).reshape(B, C, self.dim_v)

        # split into heads, apply scaled dot-product attention, then re-join the heads
        dim_split = self.dim_v // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(dim_split), 2)
        O = torch.cat(A.bmm(V_).split(B, 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        return O, A

    def get_attention(self, X):
        """Return only the attention weights (same computation as forward)."""
        B, C, H = X.shape

        Q = self.S.repeat(X.size(0), 1, 1)

        K = self.fc_k(X.reshape(-1, H)).reshape(B, C, self.dim_k)
        V = self.fc_v(X.reshape(-1, H)).reshape(B, C, self.dim_v)
        dim_split = self.dim_v // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(dim_split), 2)
        return A
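
# Shape sketch for PooledAttention (illustrative numbers, not from the original code):
# a (B, C, F) set of C slice encodings is pooled to a single (B, 1, dim_v) vector, and
# the returned attention weights have shape (num_heads * B, 1, C). For example:
#   pool = PooledAttention(input_dim=128, dim_v=128, dim_k=128, num_heads=4)
#   pooled, attn = pool(torch.randn(2, 91, 128))  # pooled: (2, 1, 128), attn: (8, 1, 91)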


def encoder_blk(in_channels, out_channels):
    """Conv -> InstanceNorm -> MaxPool -> ReLU block used to encode 2D slices."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=1),
        nn.InstanceNorm2d(out_channels),
        nn.MaxPool2d(2, stride=2),
        nn.ReLU()
    )


class MRI_ATTN(nn.Module):
    """Slice-based brain age model: every 2D slice of the input volume is encoded by a
    shared CNN, and the slice encodings are aggregated (attention, mean, or max) before
    the age regressor."""

    def __init__(self, attn_num_heads, attn_dim, attn_drop=False, agg_fn="attention", slice_dim=1,
                 *args, **kwargs):
        super(MRI_ATTN, self).__init__()

        # per-slice input shape for each choice of slicing axis
        self.input_dim = [(1, 109, 91), (91, 1, 91), (91, 109, 1)][slice_dim - 1]

        self.num_heads = attn_num_heads
        self.attn_dim = attn_dim

        # Build Encoder
        encoder_blocks = [
            encoder_blk(1, 32),
            encoder_blk(32, 64),
            encoder_blk(64, 128),
            encoder_blk(128, 256),
            encoder_blk(256, 256)
        ]
        self.encoder = nn.Sequential(*encoder_blocks)

        # average-pool kernel sized so the remaining spatial dims collapse to 1x1
        if slice_dim == 1:
            avg = nn.AvgPool2d([3, 2])
        elif slice_dim == 2:
            avg = nn.AvgPool2d([2, 2])
        elif slice_dim == 3:
            avg = nn.AvgPool2d([2, 3])
        else:
            raise Exception("Invalid slice dim")
        self.slice_dim = slice_dim

        # Post-processing: reduce channels, pool to 1x1, then project to num_heads * attn_dim
        self.post_proc = nn.Sequential(
            nn.Conv2d(256, 64, 1, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            avg,
            nn.Dropout(p=0.5) if attn_drop else nn.Identity(),
            nn.Conv2d(64, self.num_heads * self.attn_dim, 1)
        )

        # Slice aggregation
        if agg_fn == "attention":
            self.pooled_attention = PooledAttention(input_dim=self.num_heads * self.attn_dim,
                                                    dim_v=self.num_heads * self.attn_dim,
                                                    dim_k=self.num_heads * self.attn_dim,
                                                    num_heads=self.num_heads)
        elif agg_fn == "mean":
            self.pooled_attention = MeanPool()
        elif agg_fn == "max":
            self.pooled_attention = MaxPool()
        else:
            raise Exception("Invalid aggregation function")

        # Build regressor
        self.attn_post = nn.Linear(self.num_heads * self.attn_dim, 64)
        self.regressor = nn.Sequential(nn.ReLU(), nn.Linear(64, 1))
        self.init_weights()

    def init_weights(self):
        for k, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear) and "regressor" in k:
                # the regression output bias is initialised to a fixed age value
                m.bias.data.fill_(62.68)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def encode(self, x):
        B, C, H, W, D = x.size()
        if self.slice_dim == 1:
            # run the shared 2D encoder over all slices along the chosen axis in one batched pass
            new_input = torch.cat([x[:, :, i, :, :] for i in range(H)], dim=0)
            encoding = self.encoder(new_input)
            encoding = self.post_proc(encoding)
            encoding = torch.cat([i.unsqueeze(2) for i in torch.split(encoding, B, dim=0)], dim=2)
            # note: squeeze only the singleton spatial dims; a blanket squeeze() could drop the batch dim
            encoding = encoding.squeeze(4).squeeze(3)
        elif self.slice_dim == 2:
            new_input = torch.cat([x[:, :, :, i, :] for i in range(W)], dim=0)
            encoding = self.encoder(new_input)
            encoding = self.post_proc(encoding)
            encoding = torch.cat([i.unsqueeze(3) for i in torch.split(encoding, B, dim=0)], dim=3)
            # note: squeeze only the singleton spatial dims; a blanket squeeze() could drop the batch dim
            encoding = encoding.squeeze(4).squeeze(2)
        elif self.slice_dim == 3:
            new_input = torch.cat([x[:, :, :, :, i] for i in range(D)], dim=0)
            encoding = self.encoder(new_input)
            encoding = self.post_proc(encoding)
            encoding = torch.cat([i.unsqueeze(4) for i in torch.split(encoding, B, dim=0)], dim=4)
            # note: squeeze only the singleton spatial dims; a blanket squeeze() could drop the batch dim
            encoding = encoding.squeeze(3).squeeze(2)
        else:
            raise Exception("Invalid slice dim")

        # put slices on dim 1 so they form the set over which attention pools: (B, num_slices, features)
        encoding = encoding.permute((0, 2, 1))
        encoding, attention = self.pooled_attention(encoding)
        return encoding.squeeze(1), attention

    def forward(self, x):
        embedding, attention = self.encode(x)
        post = self.attn_post(embedding)
        y_pred = self.regressor(post)
        return Box({"y_pred": y_pred, "attention": attention})

    def get_attention(self, x):
        _, attention = self.encode(x)
        return attention


def get_arch(*args, **kwargs):
    return {"net": MRI_ATTN(*args, **kwargs)}
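

# Minimal usage sketch. Assumptions (not stated in the original code): inputs are
# single-channel volumes of shape (B, 1, 91, 109, 91), matching the slice shapes
# hard-coded above; the hyperparameter values below are illustrative only.
if __name__ == "__main__":
    net = get_arch(attn_num_heads=4, attn_dim=32, attn_drop=False,
                   agg_fn="attention", slice_dim=1)["net"]
    dummy = torch.randn(2, 1, 91, 109, 91)  # two example volumes
    out = net(dummy)
    print(out.y_pred.shape)     # torch.Size([2, 1]) -- one age prediction per volume
    print(out.attention.shape)  # attention over the 91 slices along slice_dim=1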