Showing
2 changed files
with
42 additions
and
14 deletions
... | @@ -6,7 +6,8 @@ import pickle as cp | ... | @@ -6,7 +6,8 @@ import pickle as cp |
6 | import glob | 6 | import glob |
7 | import numpy as np | 7 | import numpy as np |
8 | import pandas as pd | 8 | import pandas as pd |
9 | - | 9 | +from natsort import natsorted |
10 | +from PIL import Image | ||
10 | import torch | 11 | import torch |
11 | import torchvision | 12 | import torchvision |
12 | import torch.nn.functional as F | 13 | import torch.nn.functional as F |
... | @@ -16,12 +17,14 @@ from torch.utils.data import Subset | ... | @@ -16,12 +17,14 @@ from torch.utils.data import Subset |
16 | from torch.utils.data import Dataset, DataLoader | 17 | from torch.utils.data import Dataset, DataLoader |
17 | 18 | ||
18 | from sklearn.model_selection import StratifiedShuffleSplit | 19 | from sklearn.model_selection import StratifiedShuffleSplit |
20 | +from sklearn.model_selection import train_test_split | ||
21 | +from sklearn.model_selection import KFold | ||
19 | 22 | ||
20 | from networks import basenet | 23 | from networks import basenet |
21 | 24 | ||
22 | 25 | ||
23 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 26 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame' |
24 | -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' | 27 | +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame' |
25 | 28 | ||
26 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' | 29 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' |
27 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' | 30 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' |
... | @@ -32,16 +35,31 @@ current_epoch = 0 | ... | @@ -32,16 +35,31 @@ current_epoch = 0 |
32 | def split_dataset(args, dataset, k): | 35 | def split_dataset(args, dataset, k): |
33 | # load dataset | 36 | # load dataset |
34 | X = list(range(len(dataset))) | 37 | X = list(range(len(dataset))) |
35 | - Y = dataset | 38 | + #Y = dataset.targets |
36 | 39 | ||
37 | # split to k-fold | 40 | # split to k-fold |
38 | - assert len(X) == len(Y) | 41 | + # assert len(X) == len(Y) |
39 | 42 | ||
40 | def _it_to_list(_it): | 43 | def _it_to_list(_it): |
41 | return list(zip(*list(_it))) | 44 | return list(zip(*list(_it))) |
42 | 45 | ||
43 | - sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | 46 | + # sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) |
44 | - Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | 47 | + # Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) |
48 | + | ||
49 | + x_train = [] | ||
50 | + x_test = [] | ||
51 | + | ||
52 | + | ||
53 | + for i in range(k): | ||
54 | + xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) | ||
55 | + x_train.append(xtr) | ||
56 | + x_test.append(xte) | ||
57 | + | ||
58 | + #kf = KFold(n_splits=k, random_state=args.seed, test) | ||
59 | + #kf.split(x_train) | ||
60 | + | ||
61 | + Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test) | ||
62 | + | ||
45 | 63 | ||
46 | return Dm_indexes, Da_indexes | 64 | return Dm_indexes, Da_indexes |
47 | 65 | ||
... | @@ -154,20 +172,27 @@ class CustomDataset(Dataset): | ... | @@ -154,20 +172,27 @@ class CustomDataset(Dataset): |
154 | def __init__(self, path, target_path, transform = None): | 172 | def __init__(self, path, target_path, transform = None): |
155 | self.path = path | 173 | self.path = path |
156 | self.transform = transform | 174 | self.transform = transform |
157 | - #self.img = np.load(path) | 175 | + #self.imgpath = glob.glob(path + '/*.png' |
158 | - self.img = glob.glob(path + '/*.png') | 176 | + #self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3) |
159 | - self.len = len(self.img) | 177 | + self.imgs = natsorted(os.listdir(path)) |
178 | + self.len = len(self.imgs) | ||
179 | + #self.len = self.img.shape[0] | ||
160 | self.targets = pd.read_csv(target_path, header = None) | 180 | self.targets = pd.read_csv(target_path, header = None) |
161 | 181 | ||
162 | def __len__(self): | 182 | def __len__(self): |
163 | return self.len | 183 | return self.len |
164 | 184 | ||
165 | def __getitem__(self, idx): | 185 | def __getitem__(self, idx): |
166 | - img, targets = self.img[idx], self.targets[idx] | 186 | + #img, targets = self.img[idx], self.targets[idx] |
187 | + img_loc = os.path.join(self.path, self.imgs[idx]) | ||
188 | + #img = self.img[idx] | ||
189 | + image = Image.open(img_loc) | ||
167 | 190 | ||
168 | if self.transform is not None: | 191 | if self.transform is not None: |
169 | - img = self.transform(img) | 192 | + #img = self.transform(img) |
170 | - return img, targets | 193 | + tensor_image = self.transform(image) |
194 | + #return img, targets | ||
195 | + return tensor_image | ||
171 | 196 | ||
172 | def get_dataset(args, transform, split='train'): | 197 | def get_dataset(args, transform, split='train'): |
173 | assert split in ['train', 'val', 'test', 'trainval'] | 198 | assert split in ['train', 'val', 'test', 'trainval'] |
... | @@ -309,6 +334,8 @@ def get_valid_transform(args, model): | ... | @@ -309,6 +334,8 @@ def get_valid_transform(args, model): |
309 | 334 | ||
310 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | 335 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): |
311 | model.train() | 336 | model.train() |
337 | + print('\nBatch\n', batch) | ||
338 | + print('\nBatch size\n', batch.size()) | ||
312 | images, target = batch | 339 | images, target = batch |
313 | 340 | ||
314 | if device: | 341 | if device: | ... | ... |
-
Please register or login to post a comment