조현아

rm stratified FAA getBraTS_3

@@ -3,4 +3,5 @@ tb-nightly
 torchvision
 torch
 hyperopt
-fire
+pillow==6.2.1
+natsort
\ No newline at end of file

@@ -6,7 +6,8 @@ import pickle as cp
 import glob
 import numpy as np
 import pandas as pd
-
+from natsort import natsorted
+from PIL import Image
 import torch
 import torchvision
 import torch.nn.functional as F
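
natsorted is brought in so frame filenames sort in numeric order rather than lexicographic order, which matters once files are read with os.listdir in the dataset below. A quick illustration (filenames hypothetical):

from natsort import natsorted

frames = ['frame_10.png', 'frame_2.png', 'frame_1.png']
print(sorted(frames))     # ['frame_1.png', 'frame_10.png', 'frame_2.png'] -- lexicographic
print(natsorted(frames))  # ['frame_1.png', 'frame_2.png', 'frame_10.png'] -- numeric order
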
@@ -16,12 +17,14 @@ from torch.utils.data import Subset
 from torch.utils.data import Dataset, DataLoader
 
 from sklearn.model_selection import StratifiedShuffleSplit
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import KFold
 
 from networks import basenet
 
 
-TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
-VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
+TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame'
+VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame'
 
 TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
 VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
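
Dropping the trailing slash from TRAIN_DATASET_PATH and VAL_DATASET_PATH is harmless either way once paths are assembled with os.path.join, as the new __getitem__ does; join inserts exactly one separator (example filename hypothetical):

import os

print(os.path.join('/data/train_frame', 'img_001.png'))   # /data/train_frame/img_001.png
print(os.path.join('/data/train_frame/', 'img_001.png'))  # same path; a trailing slash is not doubled
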
@@ -32,16 +35,31 @@ current_epoch = 0
 def split_dataset(args, dataset, k):
     # load dataset
     X = list(range(len(dataset)))
-    Y = dataset
+    #Y = dataset.targets
 
     # split to k-fold
-    assert len(X) == len(Y)
+    # assert len(X) == len(Y)
 
     def _it_to_list(_it):
         return list(zip(*list(_it)))
 
-    sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
-    Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))
+    # sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
+    # Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))
+
+    x_train = []
+    x_test = []
+
+
+    for i in range(k):
+        xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1)
+        x_train.append(xtr)
+        x_test.append(xte)
+
+    #kf = KFold(n_splits=k, random_state=args.seed, test)
+    #kf.split(x_train)
+
+    Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test)
+
 
     return Dm_indexes, Da_indexes
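
A caveat with the replacement split: train_test_split is called with the same random_state on every iteration, so all k train/test index pairs come out identical. A minimal sketch of two ways to get distinct folds, assuming distinct folds are the intent (function name hypothetical):

import numpy as np
from sklearn.model_selection import KFold, train_test_split

def split_dataset_distinct(args, dataset, k):
    X = list(range(len(dataset)))

    # Option 1: k independent 90/10 splits, varying the seed per iteration.
    x_train, x_test = [], []
    for i in range(k):
        xtr, xte = train_test_split(X, random_state=args.seed + i, test_size=0.1)
        x_train.append(xtr)
        x_test.append(xte)

    # Option 2: a true k-fold partition, as the commented-out KFold hints
    # (random_state takes effect only with shuffle=True).
    # kf = KFold(n_splits=k, shuffle=True, random_state=args.seed)
    # x_train, x_test = zip(*kf.split(X))

    return np.array(x_train), np.array(x_test)
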
@@ -154,20 +172,27 @@ class CustomDataset(Dataset):
     def __init__(self, path, target_path, transform = None):
         self.path = path
         self.transform = transform
-        #self.img = np.load(path)
-        self.img = glob.glob(path + '/*.png')
-        self.len = len(self.img)
+        #self.imgpath = glob.glob(path + '/*.png'
+        #self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3)
+        self.imgs = natsorted(os.listdir(path))
+        self.len = len(self.imgs)
+        #self.len = self.img.shape[0]
         self.targets = pd.read_csv(target_path, header = None)
 
     def __len__(self):
         return self.len
 
     def __getitem__(self, idx):
-        img, targets = self.img[idx], self.targets[idx]
+        #img, targets = self.img[idx], self.targets[idx]
+        img_loc = os.path.join(self.path, self.imgs[idx])
+        #img = self.img[idx]
+        image = Image.open(img_loc)
 
         if self.transform is not None:
-            img = self.transform(img)
-        return img, targets
+            #img = self.transform(img)
+            tensor_image = self.transform(image)
+        #return img, targets
+        return tensor_image
 
 def get_dataset(args, transform, split='train'):
     assert split in ['train', 'val', 'test', 'trainval']
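
Two caveats in the new CustomDataset: __getitem__ no longer returns the target even though train_step below unpacks images, target from each batch, and tensor_image is unbound when transform is None, so the return would raise NameError. Also, self.targets[idx] on a DataFrame selects a column, not a row. A corrected sketch, assuming the CSV holds one label per row in the same natural-sorted order as the image files:

import os
import pandas as pd
from natsort import natsorted
from PIL import Image
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, path, target_path, transform=None):
        self.path = path
        self.transform = transform
        self.imgs = natsorted(os.listdir(path))            # numeric filename order
        self.targets = pd.read_csv(target_path, header=None)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.path, self.imgs[idx]))
        target = self.targets.iloc[idx, 0]                 # .iloc selects a row; [idx] would select a column
        if self.transform is not None:
            image = self.transform(image)                  # e.g. torchvision ToTensor()
        return image, target
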
@@ -309,6 +334,8 @@ def get_valid_transform(args, model):
 
 def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
     model.train()
+    print('\nBatch\n', batch)
+    print('\nBatch size\n', batch.size())
     images, target = batch
 
     if device:
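
One caveat on the new debug prints: batch.size() assumes batch is a single tensor, but the very next line unpacks images, target = batch, and a list or tuple has no .size(); with the modified dataset returning only tensor_image, the unpack itself would also fail. A defensive sketch for the print:

# Print shapes whether the loader yields one tensor or an (images, target) pair.
if isinstance(batch, (list, tuple)):
    print('Batch sizes:', [b.size() for b in batch])
else:
    print('Batch size:', batch.size())
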