조현아

add targets FAA getBraTS_4

@@ -5,3 +5,4 @@ torch
 hyperopt
 pillow==6.2.1
 natsort
+fire
\ No newline at end of file
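Note on the new `fire` requirement: this is presumably Google's python-fire package, which builds a command-line interface from any function or class. A minimal sketch of the pattern it enables (the `train` function and its arguments here are hypothetical, not from this repo):

```python
import fire

def train(dataset='BraTS', k=5):
    # Hypothetical entry point, used only to illustrate python-fire.
    print(f'training on {dataset} with k={k} folds')

if __name__ == '__main__':
    # Exposes flags automatically: python train.py --dataset=BraTS --k=5
    fire.Fire(train)
```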
@@ -22,9 +22,9 @@ from sklearn.model_selection import KFold
 
 from networks import basenet
 
-
-TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame'
-VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame'
+DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
+TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
+VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
 
 TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
 VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
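The path constants gain trailing slashes (plus a new `DATASET_PATH` alias pointing at the train frames). Since `CustomDataset` joins file names with `os.path.join`, the trailing slash is harmless either way; a quick check:

```python
import os

# os.path.join collapses the duplicate separator, so both spellings
# resolve to the same location:
print(os.path.join('/data/train_frame/', 'img_001.png'))  # /data/train_frame/img_001.png
print(os.path.join('/data/train_frame', 'img_001.png'))   # /data/train_frame/img_001.png
```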
@@ -35,7 +35,10 @@ current_epoch = 0
 def split_dataset(args, dataset, k):
     # load dataset
     X = list(range(len(dataset)))
-    #Y = dataset.targets
+    Y = dataset.targets
+    #Y = [0]* len(X)
+
+    #print("X:\n", type(X), np.shape(X), '\n', X, '\n')
 
     # split to k-fold
     # assert len(X) == len(Y)
@@ -43,26 +46,49 @@ def split_dataset(args, dataset, k):
     def _it_to_list(_it):
         return list(zip(*list(_it)))
 
-    # sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
-    # Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))
+    sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
+    Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))
+
+    # print(type(Dm_indexes), np.shape(Dm_indexes))
+    # print("DM\n", len(Dm_indexes), Dm_indexes, "\nDA\n", len(Da_indexes), Da_indexes)
+
+
+    return Dm_indexes, Da_indexes
+
+def split_dataset2222(args, dataset, k):
+    # load dataset
+    X = list(range(len(dataset)))
+
+    # split to k-fold
+    #assert len(X) == len(Y)
 
-    x_train = []
-    x_test = []
+    def _it_to_list(_it):
+        return list(zip(*list(_it)))
 
+    x_train = ()
+    x_test = ()
 
     for i in range(k):
-        xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1)
-        x_train.append(xtr)
-        x_test.append(xte)
+        #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1)
+        xtr, xte = train_test_split(X, random_state=None, test_size=0.1)
+        x_train.append(np.array(xtr))
+        x_test.append(np.array(xte))
 
-    #kf = KFold(n_splits=k, random_state=args.seed, test)
-    #kf.split(x_train)
+    y_train = np.array([0]* len(x_train))
+    y_test = np.array([0]* len(x_test))
 
-    Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test)
+    x_train = tuple(x_train)
+    x_test = tuple(x_test)
 
+    trainset = (zip(x_train, y_train),)
+    testset = (zip(x_test, y_test),)
 
-    return Dm_indexes, Da_indexes
+    Dm_indexes, Da_indexes = trainset, testset
 
+    print(type(Dm_indexes), np.shape(Dm_indexes))
+    print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes)
+
+    return Dm_indexes, Da_indexes
 
 def concat_image_features(image, features, max_features=3):
     _, h, w = image.shape
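The hunk above re-enables `StratifiedShuffleSplit` (which needs labels, hence `Y = dataset.targets` and the dummy zero targets added to `CustomDataset` below) and keeps an experimental `split_dataset2222` variant. Note that in the variant, `x_train` and `x_test` are initialized as tuples, so the later `.append()` calls would raise `AttributeError` if that path were ever exercised. A minimal runnable sketch of the split that `split_dataset` now performs, using lists and dummy labels:

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

def split_dataset_sketch(n_samples, k, seed=0):
    # Indices stand in for the dataset; all-zero labels mirror the dummy
    # targets used in this commit (a single class, so the "stratified"
    # split degenerates to a plain shuffle split).
    X = list(range(n_samples))
    Y = [0] * n_samples
    sss = StratifiedShuffleSplit(n_splits=k, random_state=seed, test_size=0.1)
    # zip(*) regroups the k (train, test) index pairs into (trains, tests).
    Dm_indexes, Da_indexes = list(zip(*list(sss.split(X, Y))))
    return Dm_indexes, Da_indexes

Dm, Da = split_dataset_sketch(100, k=5)
print(len(Dm), Dm[0].shape, Da[0].shape)  # 5 (90,) (10,)
```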
@@ -169,22 +195,24 @@ def select_scheduler(args, optimizer):
 
 
 class CustomDataset(Dataset):
-    def __init__(self, path, target_path, transform = None):
+    def __init__(self, path, transform = None):
         self.path = path
         self.transform = transform
         #self.imgpath = glob.glob(path + '/*.png'
-        #self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3)
         self.imgs = natsorted(os.listdir(path))
         self.len = len(self.imgs)
         #self.len = self.img.shape[0]
-        self.targets = pd.read_csv(target_path, header = None)
+        self.targets = [0]* self.len
 
     def __len__(self):
         return self.len
 
     def __getitem__(self, idx):
+        # print("\n\nIDX: ", idx, '\n', type(idx), '\n')
+        # print("\n\nimgs[idx]: ", self.imgs[idx], '\n', type(self.imgs[idx]), '\n')
         #img, targets = self.img[idx], self.targets[idx]
         img_loc = os.path.join(self.path, self.imgs[idx])
+        targets = self.targets[idx]
         #img = self.img[idx]
         image = Image.open(img_loc)
 
@@ -192,7 +220,7 @@ class CustomDataset(Dataset):
         #img = self.transform(img)
         tensor_image = self.transform(image)
         #return img, targets
-        return tensor_image
+        return tensor_image, targets
 
 def get_dataset(args, transform, split='train'):
     assert split in ['train', 'val', 'test', 'trainval']
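With the target CSV dropped, `CustomDataset` now pairs each image with a dummy zero label and returns `(tensor_image, targets)` tuples, which keeps the downstream `images, target = batch` unpacking intact. A self-contained sketch of the same pattern (directory layout and transform are illustrative assumptions):

```python
import os
from PIL import Image
from natsort import natsorted
from torch.utils.data import Dataset
from torchvision import transforms

class DummyTargetDataset(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform or transforms.ToTensor()
        # Natural sort keeps frame_2.png before frame_10.png.
        self.imgs = natsorted(os.listdir(path))
        self.targets = [0] * len(self.imgs)  # placeholder labels

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.path, self.imgs[idx]))
        return self.transform(image), self.targets[idx]
```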
@@ -224,9 +252,9 @@ def get_dataset(args, transform, split='train'):
 
     elif args.dataset == 'BraTS':
         if split in ['train']:
-            dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform)
+            dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform)
         else:
-            dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform)
+            dataset = CustomDataset(VAL_DATASET_PATH, transform=transform)
 
 
     else:
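After this change, constructing the BraTS datasets needs only the two frame directories; the `*_TARGET_PATH` CSV constants are now unused on this branch. A hypothetical call site (the transform is an assumption):

```python
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])
train_set = CustomDataset(TRAIN_DATASET_PATH, transform=transform)
val_set   = CustomDataset(VAL_DATASET_PATH, transform=transform)
print(len(train_set), train_set[0][1])  # N images, dummy target 0
```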
@@ -250,6 +278,7 @@ def get_inf_dataloader(args, dataset):
 
     while True:
         try:
+            #print("batch=dataloader:\n", batch, '\n')
            batch = next(data_loader)
 
         except StopIteration:
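One caveat with the debug line added above: it references `batch` before the first `next(data_loader)` call, so uncommenting it would raise `NameError` on the first iteration. A minimal sketch of the infinite-batch pattern this function implements, assuming a standard `torch.utils.data.DataLoader`:

```python
from torch.utils.data import DataLoader

def get_inf_batches(dataset, batch_size=4):
    data_loader = iter(DataLoader(dataset, batch_size=batch_size, shuffle=True))
    while True:
        try:
            batch = next(data_loader)
        except StopIteration:
            # Restart the epoch when the loader is exhausted.
            data_loader = iter(DataLoader(dataset, batch_size=batch_size, shuffle=True))
            batch = next(data_loader)
        yield batch
```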
@@ -334,8 +363,7 @@ def get_valid_transform(args, model):
 
 def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
     model.train()
-    print('\nBatch\n', batch)
-    print('\nBatch size\n', batch.size())
+    #print('\nBatch\n', batch)
     images, target = batch
 
     if device:
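The removed `batch.size()` print would have failed anyway: with `CustomDataset.__getitem__` returning `(tensor_image, targets)`, torch's default collate turns each batch into a list `[images, targets]`, which has no `.size()`. The unpacking that follows now lines up, roughly:

```python
# Shapes under torch's default collate_fn (batch size B assumed):
images, target = batch   # images: FloatTensor[B, C, H, W]; target: LongTensor[B]
if device:
    images = images.to(device)
    target = target.to(device)
```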