Toggle navigation
Toggle navigation
This project
Loading...
Sign in
2020-1-capstone-design2
/
2016104167
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
조현아
2020-03-31 17:09:31 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
c31b7168c216c6d11ea5f33c9874df3098717d72
c31b7168
1 parent
165fdd31
add targets FAA getBraTS_4
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
25 deletions
code/FAA2/requirements.txt
code/FAA2/utils.py
code/FAA2/requirements.txt
View file @
c31b716
...
...
@@ -4,4 +4,5 @@ torchvision
torch
hyperopt
pillow==6.2.1
natsort
\ No newline at end of file
natsort
fire
\ No newline at end of file
...
...
code/FAA2/utils.py
View file @
c31b716
...
...
@@ -22,9 +22,9 @@ from sklearn.model_selection import KFold
from
networks
import
basenet
# Dataset locations on the mounted Google Drive (Colab layout).
# NOTE(review): the original contained duplicate assignments of
# TRAIN_DATASET_PATH / VAL_DATASET_PATH (without the trailing '/') that were
# immediately overwritten — dead code, removed here. Only the final values
# are kept.
DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
# Per-split target (label) CSV files consumed by CustomDataset.
TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
...
...
@@ -35,7 +35,10 @@ current_epoch = 0
def split_dataset(args, dataset, k):
    """Build k stratified (train-side, test-side) index splits of `dataset`.

    Args:
        args: namespace with a `seed` attribute used to make the shuffling
            reproducible.
        dataset: dataset exposing `__len__` and a `targets` sequence used
            for stratification.
        k: number of stratified shuffle splits to generate.

    Returns:
        (Dm_indexes, Da_indexes): two length-k tuples holding, respectively,
        the train-side and the held-out-side index arrays of each split.
    """
    # Split positional indices, stratifying on the dataset's targets.
    X = list(range(len(dataset)))
    Y = dataset.targets

    def _it_to_list(_it):
        # Transpose [(train, test), ...] into ((train, ...), (test, ...)).
        return list(zip(*list(_it)))

    sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed,
                                 test_size=0.1)
    Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))
    # NOTE(review): the original assigned x_train = [] / x_test = [] right
    # before returning — dead code, removed.
    return Dm_indexes, Da_indexes
def split_dataset2222(args, dataset, k):
    """Build k random (non-stratified) train/test index splits of `dataset`.

    Args:
        args: namespace with a `seed` attribute (kept for interface parity
            with split_dataset; see the random_state note below).
        dataset: any sized object (`__len__`).
        k: number of random splits to generate.

    Returns:
        (Dm_indexes, Da_indexes): numpy arrays of shape (k, n_train) and
        (k, n_test) holding the sample indices of each fold.
    """
    X = list(range(len(dataset)))

    # Accumulate the folds in lists. The original initialised these as
    # tuples — tuple has no .append(), so the first loop iteration raised
    # AttributeError.
    x_train = []
    x_test = []
    for _ in range(k):
        # random_state=None: a fixed seed here would make all k "folds"
        # identical; the original's later variant also used None.
        xtr, xte = train_test_split(X, random_state=None, test_size=0.1)
        x_train.append(np.array(xtr))
        x_test.append(np.array(xte))

    # NOTE(review): a large unreachable region after the original's first
    # return (dummy y arrays, zipped trainset/testset, debug prints, a
    # second return) has been removed — it could never execute.
    Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test)
    return Dm_indexes, Da_indexes
def
concat_image_features
(
image
,
features
,
max_features
=
3
):
_
,
h
,
w
=
image
.
shape
...
...
@@ -169,22 +195,24 @@ def select_scheduler(args, optimizer):
class
CustomDataset
(
Dataset
):
def
__init__
(
self
,
path
,
t
arget_path
,
t
ransform
=
None
):
def
__init__
(
self
,
path
,
transform
=
None
):
self
.
path
=
path
self
.
transform
=
transform
#self.imgpath = glob.glob(path + '/*.png'
#self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3)
self
.
imgs
=
natsorted
(
os
.
listdir
(
path
))
self
.
len
=
len
(
self
.
imgs
)
#self.len = self.img.shape[0]
self
.
targets
=
pd
.
read_csv
(
target_path
,
header
=
None
)
self
.
targets
=
[
0
]
*
self
.
len
def
__len__
(
self
):
return
self
.
len
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
# print("\n\nIDX: ", idx, '\n', type(idx), '\n')
# print("\n\nimgs[idx]: ", self.imgs[idx], '\n', type(self.imgs[idx]), '\n')
#img, targets = self.img[idx], self.targets[idx]
img_loc
=
os
.
path
.
join
(
self
.
path
,
self
.
imgs
[
idx
])
targets
=
self
.
targets
[
idx
]
#img = self.img[idx]
image
=
Image
.
open
(
img_loc
)
...
...
@@ -192,7 +220,7 @@ class CustomDataset(Dataset):
#img = self.transform(img)
tensor_image
=
self
.
transform
(
image
)
#return img, targets
return
tensor_image
return
tensor_image
,
targets
def
get_dataset
(
args
,
transform
,
split
=
'train'
):
assert
split
in
[
'train'
,
'val'
,
'test'
,
'trainval'
]
...
...
@@ -224,9 +252,9 @@ def get_dataset(args, transform, split='train'):
elif
args
.
dataset
==
'BraTS'
:
if
split
in
[
'train'
]:
dataset
=
CustomDataset
(
TRAIN_DATASET_PATH
,
TRAIN_TARGET_PATH
,
transform
=
transform
)
dataset
=
CustomDataset
(
TRAIN_DATASET_PATH
,
transform
=
transform
)
else
:
dataset
=
CustomDataset
(
VAL_DATASET_PATH
,
VAL_TARGET_PATH
,
transform
=
transform
)
dataset
=
CustomDataset
(
VAL_DATASET_PATH
,
transform
=
transform
)
else
:
...
...
@@ -250,6 +278,7 @@ def get_inf_dataloader(args, dataset):
while
True
:
try
:
#print("batch=dataloader:\n", batch, '\n')
batch
=
next
(
data_loader
)
except
StopIteration
:
...
...
@@ -334,8 +363,7 @@ def get_valid_transform(args, model):
def
train_step
(
args
,
model
,
optimizer
,
scheduler
,
criterion
,
batch
,
step
,
writer
,
device
=
None
):
model
.
train
()
print
(
'
\n
Batch
\n
'
,
batch
)
print
(
'
\n
Batch size
\n
'
,
batch
.
size
())
#print('\nBatch\n', batch)
images
,
target
=
batch
if
device
:
...
...
Please
register
or
login
to post a comment