Showing
1 changed file
with
0 additions
and
46 deletions
src/arch/brain_age_3d.py
deleted
100644 → 0
1 | -""" 3D brain age model""" | ||
2 | -from box import Box | ||
3 | -from torch import nn | ||
4 | -from torch.nn import init | ||
5 | - | ||
6 | - | ||
7 | -def conv_blk(in_channel, out_channel): | ||
8 | - return nn.Sequential( | ||
9 | - nn.Conv3d(in_channel, out_channel, kernel_size=3, stride=1, padding=1), | ||
10 | - nn.InstanceNorm3d(out_channel), nn.MaxPool3d(2, stride=2), nn.ReLU() | ||
11 | - ) | ||
12 | - | ||
13 | - | ||
14 | -class Model(nn.Module): | ||
15 | - def __init__(self): | ||
16 | - super(Model, self).__init__() | ||
17 | - self.conv1 = conv_blk(1, 32) | ||
18 | - self.conv2 = conv_blk(32, 64) | ||
19 | - self.conv3 = conv_blk(64, 128) | ||
20 | - self.conv4 = conv_blk(128, 256) | ||
21 | - self.conv5 = conv_blk(256, 256) | ||
22 | - | ||
23 | - self.conv6 = nn.Sequential(nn.Conv3d(256, 64, kernel_size=1, stride=1), | ||
24 | - nn.InstanceNorm3d(64), nn.ReLU(), | ||
25 | - nn.AvgPool3d(kernel_size=(2, 3, 2))) | ||
26 | - | ||
27 | - self.drop = nn.Dropout3d(p=0.5) | ||
28 | - | ||
29 | - self.output = nn.Conv3d(64, 1, kernel_size=1, stride=1) | ||
30 | - | ||
31 | - init.constant_(self.output.bias, 62.68) | ||
32 | - | ||
33 | - def forward(self, x): | ||
34 | - x = self.conv1(x) | ||
35 | - x = self.conv2(x) | ||
36 | - x = self.conv3(x) | ||
37 | - x = self.conv4(x) | ||
38 | - x = self.conv5(x) | ||
39 | - x = self.conv6(x) | ||
40 | - x = self.drop(x) | ||
41 | - x = self.output(x) | ||
42 | - return Box({"y_pred": x}) | ||
43 | - | ||
44 | - | ||
45 | -def get_arch(*args, **kwargs): | ||
46 | - return {"net": Model()} |
-
Please register or login to post a comment