Toggle navigation
Toggle navigation
This project
Loading...
Sign in
2020-1-capstone-design2
/
2016104167
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
조현아
2020-03-30 23:10:12 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
165fdd319742d7d286ae40214c9484c011edca78
165fdd31
1 parent
165bb19a
rm stratified FAA getBraTS_3
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
15 deletions
code/FAA2/requirements.txt
code/FAA2/utils.py
code/FAA2/requirements.txt
View file @
165fdd3
...
...
@@ -3,4 +3,5 @@ tb-nightly
torchvision
torch
hyperopt
fire
pillow==6.2.1
natsort
\ No newline at end of file
...
...
code/FAA2/utils.py
View file @
165fdd3
...
...
@@ -6,7 +6,8 @@ import pickle as cp
import
glob
import
numpy
as
np
import
pandas
as
pd
from
natsort
import
natsorted
from
PIL
import
Image
import
torch
import
torchvision
import
torch.nn.functional
as
F
...
...
@@ -16,12 +17,14 @@ from torch.utils.data import Subset
from
torch.utils.data
import
Dataset
,
DataLoader
from
sklearn.model_selection
import
StratifiedShuffleSplit
from
sklearn.model_selection
import
train_test_split
from
sklearn.model_selection
import
KFold
from
networks
import
basenet
TRAIN_DATASET_PATH
=
'/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame
/
'
VAL_DATASET_PATH
=
'/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame
/
'
TRAIN_DATASET_PATH
=
'/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame'
VAL_DATASET_PATH
=
'/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame'
TRAIN_TARGET_PATH
=
'/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
VAL_TARGET_PATH
=
'/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
...
...
@@ -32,16 +35,31 @@ current_epoch = 0
def
split_dataset
(
args
,
dataset
,
k
):
# load dataset
X
=
list
(
range
(
len
(
dataset
)))
Y
=
dataset
#Y = dataset.targets
# split to k-fold
assert
len
(
X
)
==
len
(
Y
)
#
assert len(X) == len(Y)
def
_it_to_list
(
_it
):
return
list
(
zip
(
*
list
(
_it
)))
sss
=
StratifiedShuffleSplit
(
n_splits
=
k
,
random_state
=
args
.
seed
,
test_size
=
0.1
)
Dm_indexes
,
Da_indexes
=
_it_to_list
(
sss
.
split
(
X
,
Y
))
# sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
# Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))
x_train
=
[]
x_test
=
[]
for
i
in
range
(
k
):
xtr
,
xte
=
train_test_split
(
X
,
random_state
=
args
.
seed
,
test_size
=
0.1
)
x_train
.
append
(
xtr
)
x_test
.
append
(
xte
)
#kf = KFold(n_splits=k, random_state=args.seed, test)
#kf.split(x_train)
Dm_indexes
,
Da_indexes
=
np
.
array
(
x_train
),
np
.
array
(
x_test
)
return
Dm_indexes
,
Da_indexes
...
...
@@ -154,20 +172,27 @@ class CustomDataset(Dataset):
def
__init__
(
self
,
path
,
target_path
,
transform
=
None
):
self
.
path
=
path
self
.
transform
=
transform
#self.img = np.load(path)
self
.
img
=
glob
.
glob
(
path
+
'/*.png'
)
self
.
len
=
len
(
self
.
img
)
#self.imgpath = glob.glob(path + '/*.png'
#self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3)
self
.
imgs
=
natsorted
(
os
.
listdir
(
path
))
self
.
len
=
len
(
self
.
imgs
)
#self.len = self.img.shape[0]
self
.
targets
=
pd
.
read_csv
(
target_path
,
header
=
None
)
def
__len__
(
self
):
return
self
.
len
def
__getitem__
(
self
,
idx
):
img
,
targets
=
self
.
img
[
idx
],
self
.
targets
[
idx
]
#img, targets = self.img[idx], self.targets[idx]
img_loc
=
os
.
path
.
join
(
self
.
path
,
self
.
imgs
[
idx
])
#img = self.img[idx]
image
=
Image
.
open
(
img_loc
)
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
return
img
,
targets
#img = self.transform(img)
tensor_image
=
self
.
transform
(
image
)
#return img, targets
return
tensor_image
def
get_dataset
(
args
,
transform
,
split
=
'train'
):
assert
split
in
[
'train'
,
'val'
,
'test'
,
'trainval'
]
...
...
@@ -309,6 +334,8 @@ def get_valid_transform(args, model):
def
train_step
(
args
,
model
,
optimizer
,
scheduler
,
criterion
,
batch
,
step
,
writer
,
device
=
None
):
model
.
train
()
print
(
'
\n
Batch
\n
'
,
batch
)
print
(
'
\n
Batch size
\n
'
,
batch
.
size
())
images
,
target
=
batch
if
device
:
...
...
Please
register
or
login
to post a comment