Hyunji

brain age 3d

1 +""" 3D brain age model"""
2 +from box import Box
3 +from torch import nn
4 +from torch.nn import init
5 +
6 +
7 +def conv_blk(in_channel, out_channel):
8 + return nn.Sequential(
9 + nn.Conv3d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
10 + nn.InstanceNorm3d(out_channel), nn.MaxPool3d(2, stride=2), nn.ReLU()
11 + )
12 +
13 +
14 +class Model(nn.Module):
15 + def __init__(self):
16 + super(Model, self).__init__()
17 + self.conv1 = conv_blk(1, 32)
18 + self.conv2 = conv_blk(32, 64)
19 + self.conv3 = conv_blk(64, 128)
20 + self.conv4 = conv_blk(128, 256)
21 + self.conv5 = conv_blk(256, 256)
22 +
23 + self.conv6 = nn.Sequential(nn.Conv3d(256, 64, kernel_size=1, stride=1),
24 + nn.InstanceNorm3d(64), nn.ReLU(),
25 + nn.AvgPool3d(kernel_size=(2, 3, 2)))
26 +
27 + self.drop = nn.Dropout3d(p=0.5)
28 +
29 + self.output = nn.Conv3d(64, 1, kernel_size=1, stride=1)
30 +
31 + init.constant_(self.output.bias, 62.68)
32 +
33 + def forward(self, x):
34 + x = self.conv1(x)
35 + x = self.conv2(x)
36 + x = self.conv3(x)
37 + x = self.conv4(x)
38 + x = self.conv5(x)
39 + x = self.conv6(x)
40 + x = self.drop(x)
41 + x = self.output(x)
42 + return Box({"y_pred": x})
43 +
44 +
45 +def get_arch(*args, **kwargs):
46 + return {"net": Model()}