Toggle navigation
Toggle navigation
This project
Loading...
Sign in
Hyunji
/
CapstoneDesign2021-1
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
Hyunji
2021-06-21 19:27:07 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
87efd37ef82d7c0fbaae3e07dca2f57c41f5365a
87efd37e
1 parent
33239aae
brain age slice LSTM
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
129 additions
and
0 deletions
src/arch/brain_age_slice_lstm.py
src/arch/brain_age_slice_lstm.py
0 → 100644
View file @
87efd37
import
torch
from
box
import
Box
from
torch
import
nn
def
encoder_blk
(
in_channels
,
out_channels
):
return
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
out_channels
,
3
,
padding
=
1
,
stride
=
1
),
nn
.
InstanceNorm2d
(
out_channels
),
nn
.
MaxPool2d
(
2
,
stride
=
2
),
nn
.
ReLU
()
)
class
MRI_LSTM
(
nn
.
Module
):
def
__init__
(
self
,
lstm_feat_dim
,
lstm_latent_dim
,
slice_dim
,
*
args
,
**
kwargs
):
super
(
MRI_LSTM
,
self
)
.
__init__
()
self
.
input_dim
=
[(
1
,
109
,
91
),
(
91
,
1
,
91
),
(
91
,
109
,
1
)][
slice_dim
-
1
]
self
.
feat_embed_dim
=
lstm_feat_dim
self
.
latent_dim
=
lstm_latent_dim
# Build Encoder
encoder_blocks
=
[
encoder_blk
(
1
,
32
),
encoder_blk
(
32
,
64
),
encoder_blk
(
64
,
128
),
encoder_blk
(
128
,
256
),
encoder_blk
(
256
,
256
)
]
self
.
encoder
=
nn
.
Sequential
(
*
encoder_blocks
)
if
slice_dim
==
1
:
avg
=
nn
.
AvgPool2d
([
3
,
2
])
elif
slice_dim
==
2
:
avg
=
nn
.
AvgPool2d
([
2
,
2
])
elif
slice_dim
==
3
:
avg
=
nn
.
AvgPool2d
([
2
,
3
])
else
:
raise
Exception
(
"Invalid slice dim"
)
self
.
slice_dim
=
slice_dim
# Post processing
self
.
post_proc
=
nn
.
Sequential
(
nn
.
Conv2d
(
256
,
64
,
1
,
stride
=
1
),
nn
.
InstanceNorm2d
(
64
),
nn
.
ReLU
(),
avg
,
nn
.
Dropout
(
p
=
0.5
),
nn
.
Conv2d
(
64
,
self
.
feat_embed_dim
,
1
)
)
# Connect w/ LSTM
self
.
n_layers
=
1
self
.
lstm
=
nn
.
LSTM
(
self
.
feat_embed_dim
,
self
.
latent_dim
,
self
.
n_layers
,
batch_first
=
True
)
# Build regressor
self
.
lstm_post
=
nn
.
Linear
(
self
.
latent_dim
,
64
)
self
.
regressor
=
nn
.
Sequential
(
nn
.
ReLU
(),
nn
.
Linear
(
64
,
1
))
self
.
init_weights
()
def
init_weights
(
self
):
for
k
,
m
in
self
.
named_modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
nn
.
init
.
kaiming_normal_
(
m
.
weight
,
mode
=
"fan_out"
,
nonlinearity
=
"relu"
)
if
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
Linear
)
and
"regressor"
in
k
:
m
.
bias
.
data
.
fill_
(
62.68
)
elif
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
normal_
(
m
.
weight
,
0
,
0.01
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
def
init_hidden
(
self
,
x
):
h_0
=
torch
.
zeros
(
self
.
n_layers
,
x
.
size
(
0
),
self
.
latent_dim
,
device
=
x
.
device
)
c_0
=
torch
.
zeros
(
self
.
n_layers
,
x
.
size
(
0
),
self
.
latent_dim
,
device
=
x
.
device
)
h_0
.
requires_grad
=
True
c_0
.
requires_grad
=
True
return
h_0
,
c_0
def
encode
(
self
,
x
):
h_0
,
c_0
=
self
.
init_hidden
(
x
)
B
,
C
,
H
,
W
,
D
=
x
.
size
()
if
self
.
slice_dim
==
1
:
new_input
=
torch
.
cat
([
x
[:,
:,
i
,
:,
:]
for
i
in
range
(
H
)],
dim
=
0
)
encoding
=
self
.
encoder
(
new_input
)
encoding
=
self
.
post_proc
(
encoding
)
encoding
=
torch
.
cat
([
i
.
unsqueeze
(
2
)
for
i
in
torch
.
split
(
encoding
,
B
,
dim
=
0
)],
dim
=
2
)
# note: squeezing is bad because batch dim can be dropped
encoding
=
encoding
.
squeeze
(
4
)
.
squeeze
(
3
)
elif
self
.
slice_dim
==
2
:
new_input
=
torch
.
cat
([
x
[:,
:,
:,
i
,
:]
for
i
in
range
(
W
)],
dim
=
0
)
encoding
=
self
.
encoder
(
new_input
)
encoding
=
self
.
post_proc
(
encoding
)
encoding
=
torch
.
cat
([
i
.
unsqueeze
(
3
)
for
i
in
torch
.
split
(
encoding
,
B
,
dim
=
0
)],
dim
=
3
)
# note: squeezing is bad because batch dim can be dropped
encoding
=
encoding
.
squeeze
(
4
)
.
squeeze
(
2
)
elif
self
.
slice_dim
==
3
:
new_input
=
torch
.
cat
([
x
[:,
:,
:,
:,
i
]
for
i
in
range
(
D
)],
dim
=
0
)
encoding
=
self
.
encoder
(
new_input
)
encoding
=
self
.
post_proc
(
encoding
)
encoding
=
torch
.
cat
([
i
.
unsqueeze
(
4
)
for
i
in
torch
.
split
(
encoding
,
B
,
dim
=
0
)],
dim
=
4
)
# note: squeezing is bad because batch dim can be dropped
encoding
=
encoding
.
squeeze
(
3
)
.
squeeze
(
2
)
else
:
raise
Exception
(
"Invalid slice dim"
)
# lstm take batch x seq_len x dim
encoding
=
encoding
.
permute
(
0
,
2
,
1
)
_
,
(
encoding
,
_
)
=
self
.
lstm
(
encoding
)
# output is 1 X batch x hidden
encoding
=
encoding
.
squeeze
(
0
)
# pass it to lstm and get encoding
return
encoding
def
forward
(
self
,
x
):
embedding
=
self
.
encode
(
x
)
post
=
self
.
lstm_post
(
embedding
)
y_pred
=
self
.
regressor
(
post
)
return
Box
({
"y_pred"
:
y_pred
})
def
get_arch
(
*
args
,
**
kwargs
):
return
{
"net"
:
MRI_LSTM
(
*
args
,
**
kwargs
)}
Please
register
or
login
to post a comment