Hyunji

Delete brain_age_3d.py

1 -""" 3D brain age model"""
2 -from box import Box
3 -from torch import nn
4 -from torch.nn import init
5 -
6 -
7 -def conv_blk(in_channel, out_channel):
8 - return nn.Sequential(
9 - nn.Conv3d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
10 - nn.InstanceNorm3d(out_channel), nn.MaxPool3d(2, stride=2), nn.ReLU()
11 - )
12 -
13 -
14 -class Model(nn.Module):
15 - def __init__(self):
16 - super(Model, self).__init__()
17 - self.conv1 = conv_blk(1, 32)
18 - self.conv2 = conv_blk(32, 64)
19 - self.conv3 = conv_blk(64, 128)
20 - self.conv4 = conv_blk(128, 256)
21 - self.conv5 = conv_blk(256, 256)
22 -
23 - self.conv6 = nn.Sequential(nn.Conv3d(256, 64, kernel_size=1, stride=1),
24 - nn.InstanceNorm3d(64), nn.ReLU(),
25 - nn.AvgPool3d(kernel_size=(2, 3, 2)))
26 -
27 - self.drop = nn.Dropout3d(p=0.5)
28 -
29 - self.output = nn.Conv3d(64, 1, kernel_size=1, stride=1)
30 -
31 - init.constant_(self.output.bias, 62.68)
32 -
33 - def forward(self, x):
34 - x = self.conv1(x)
35 - x = self.conv2(x)
36 - x = self.conv3(x)
37 - x = self.conv4(x)
38 - x = self.conv5(x)
39 - x = self.conv6(x)
40 - x = self.drop(x)
41 - x = self.output(x)
42 - return Box({"y_pred": x})
43 -
44 -
45 -def get_arch(*args, **kwargs):
46 - return {"net": Model()}