Showing
1 changed file
with
46 additions
and
0 deletions
2DCNN/src/arch/brain_age_3d.py
0 → 100644
1 | +""" 3D brain age model""" | ||
2 | +from box import Box | ||
3 | +from torch import nn | ||
4 | +from torch.nn import init | ||
5 | + | ||
6 | + | ||
7 | +def conv_blk(in_channel, out_channel): | ||
8 | + return nn.Sequential( | ||
9 | + nn.Conv3d(in_channel, out_channel, kernel_size=3, stride=1, padding=1), | ||
10 | + nn.InstanceNorm3d(out_channel), nn.MaxPool3d(2, stride=2), nn.ReLU() | ||
11 | + ) | ||
12 | + | ||
13 | + | ||
14 | +class Model(nn.Module): | ||
15 | + def __init__(self): | ||
16 | + super(Model, self).__init__() | ||
17 | + self.conv1 = conv_blk(1, 32) | ||
18 | + self.conv2 = conv_blk(32, 64) | ||
19 | + self.conv3 = conv_blk(64, 128) | ||
20 | + self.conv4 = conv_blk(128, 256) | ||
21 | + self.conv5 = conv_blk(256, 256) | ||
22 | + | ||
23 | + self.conv6 = nn.Sequential(nn.Conv3d(256, 64, kernel_size=1, stride=1), | ||
24 | + nn.InstanceNorm3d(64), nn.ReLU(), | ||
25 | + nn.AvgPool3d(kernel_size=(2, 3, 2))) | ||
26 | + | ||
27 | + self.drop = nn.Dropout3d(p=0.5) | ||
28 | + | ||
29 | + self.output = nn.Conv3d(64, 1, kernel_size=1, stride=1) | ||
30 | + | ||
31 | + init.constant_(self.output.bias, 62.68) | ||
32 | + | ||
33 | + def forward(self, x): | ||
34 | + x = self.conv1(x) | ||
35 | + x = self.conv2(x) | ||
36 | + x = self.conv3(x) | ||
37 | + x = self.conv4(x) | ||
38 | + x = self.conv5(x) | ||
39 | + x = self.conv6(x) | ||
40 | + x = self.drop(x) | ||
41 | + x = self.output(x) | ||
42 | + return Box({"y_pred": x}) | ||
43 | + | ||
44 | + | ||
45 | +def get_arch(*args, **kwargs): | ||
46 | + return {"net": Model()} |
-
Please register or login to post a comment