Showing
2 changed files
with
42 additions
and
14 deletions
| ... | @@ -6,7 +6,8 @@ import pickle as cp | ... | @@ -6,7 +6,8 @@ import pickle as cp |
| 6 | import glob | 6 | import glob |
| 7 | import numpy as np | 7 | import numpy as np |
| 8 | import pandas as pd | 8 | import pandas as pd |
| 9 | - | 9 | +from natsort import natsorted |
| 10 | +from PIL import Image | ||
| 10 | import torch | 11 | import torch |
| 11 | import torchvision | 12 | import torchvision |
| 12 | import torch.nn.functional as F | 13 | import torch.nn.functional as F |
| ... | @@ -16,12 +17,14 @@ from torch.utils.data import Subset | ... | @@ -16,12 +17,14 @@ from torch.utils.data import Subset |
| 16 | from torch.utils.data import Dataset, DataLoader | 17 | from torch.utils.data import Dataset, DataLoader |
| 17 | 18 | ||
| 18 | from sklearn.model_selection import StratifiedShuffleSplit | 19 | from sklearn.model_selection import StratifiedShuffleSplit |
| 20 | +from sklearn.model_selection import train_test_split | ||
| 21 | +from sklearn.model_selection import KFold | ||
| 19 | 22 | ||
| 20 | from networks import basenet | 23 | from networks import basenet |
| 21 | 24 | ||
| 22 | 25 | ||
| 23 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 26 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame' |
| 24 | -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' | 27 | +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame' |
| 25 | 28 | ||
| 26 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' | 29 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' |
| 27 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' | 30 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' |
| ... | @@ -32,16 +35,31 @@ current_epoch = 0 | ... | @@ -32,16 +35,31 @@ current_epoch = 0 |
| 32 | def split_dataset(args, dataset, k): | 35 | def split_dataset(args, dataset, k): |
| 33 | # load dataset | 36 | # load dataset |
| 34 | X = list(range(len(dataset))) | 37 | X = list(range(len(dataset))) |
| 35 | - Y = dataset | 38 | + #Y = dataset.targets |
| 36 | 39 | ||
| 37 | # split to k-fold | 40 | # split to k-fold |
| 38 | - assert len(X) == len(Y) | 41 | + # assert len(X) == len(Y) |
| 39 | 42 | ||
| 40 | def _it_to_list(_it): | 43 | def _it_to_list(_it): |
| 41 | return list(zip(*list(_it))) | 44 | return list(zip(*list(_it))) |
| 42 | 45 | ||
| 43 | - sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | 46 | + # sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) |
| 44 | - Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | 47 | + # Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) |
| 48 | + | ||
| 49 | + x_train = [] | ||
| 50 | + x_test = [] | ||
| 51 | + | ||
| 52 | + | ||
| 53 | + for i in range(k): | ||
| 54 | + xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) | ||
| 55 | + x_train.append(xtr) | ||
| 56 | + x_test.append(xte) | ||
| 57 | + | ||
| 58 | + #kf = KFold(n_splits=k, random_state=args.seed, test) | ||
| 59 | + #kf.split(x_train) | ||
| 60 | + | ||
| 61 | + Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test) | ||
| 62 | + | ||
| 45 | 63 | ||
| 46 | return Dm_indexes, Da_indexes | 64 | return Dm_indexes, Da_indexes |
| 47 | 65 | ||
| ... | @@ -154,20 +172,27 @@ class CustomDataset(Dataset): | ... | @@ -154,20 +172,27 @@ class CustomDataset(Dataset): |
| 154 | def __init__(self, path, target_path, transform = None): | 172 | def __init__(self, path, target_path, transform = None): |
| 155 | self.path = path | 173 | self.path = path |
| 156 | self.transform = transform | 174 | self.transform = transform |
| 157 | - #self.img = np.load(path) | 175 | + #self.imgpath = glob.glob(path + '/*.png' |
| 158 | - self.img = glob.glob(path + '/*.png') | 176 | + #self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3) |
| 159 | - self.len = len(self.img) | 177 | + self.imgs = natsorted(os.listdir(path)) |
| 178 | + self.len = len(self.imgs) | ||
| 179 | + #self.len = self.img.shape[0] | ||
| 160 | self.targets = pd.read_csv(target_path, header = None) | 180 | self.targets = pd.read_csv(target_path, header = None) |
| 161 | 181 | ||
| 162 | def __len__(self): | 182 | def __len__(self): |
| 163 | return self.len | 183 | return self.len |
| 164 | 184 | ||
| 165 | def __getitem__(self, idx): | 185 | def __getitem__(self, idx): |
| 166 | - img, targets = self.img[idx], self.targets[idx] | 186 | + #img, targets = self.img[idx], self.targets[idx] |
| 187 | + img_loc = os.path.join(self.path, self.imgs[idx]) | ||
| 188 | + #img = self.img[idx] | ||
| 189 | + image = Image.open(img_loc) | ||
| 167 | 190 | ||
| 168 | if self.transform is not None: | 191 | if self.transform is not None: |
| 169 | - img = self.transform(img) | 192 | + #img = self.transform(img) |
| 170 | - return img, targets | 193 | + tensor_image = self.transform(image) |
| 194 | + #return img, targets | ||
| 195 | + return tensor_image | ||
| 171 | 196 | ||
| 172 | def get_dataset(args, transform, split='train'): | 197 | def get_dataset(args, transform, split='train'): |
| 173 | assert split in ['train', 'val', 'test', 'trainval'] | 198 | assert split in ['train', 'val', 'test', 'trainval'] |
| ... | @@ -309,6 +334,8 @@ def get_valid_transform(args, model): | ... | @@ -309,6 +334,8 @@ def get_valid_transform(args, model): |
| 309 | 334 | ||
| 310 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | 335 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): |
| 311 | model.train() | 336 | model.train() |
| 337 | + print('\nBatch\n', batch) | ||
| 338 | + print('\nBatch size\n', batch.size()) | ||
| 312 | images, target = batch | 339 | images, target = batch |
| 313 | 340 | ||
| 314 | if device: | 341 | if device: | ... | ... |
-
Please register or login to post a comment