Showing
2 changed files
with
52 additions
and
23 deletions
... | @@ -22,9 +22,9 @@ from sklearn.model_selection import KFold | ... | @@ -22,9 +22,9 @@ from sklearn.model_selection import KFold |
22 | 22 | ||
23 | from networks import basenet | 23 | from networks import basenet |
24 | 24 | ||
25 | - | 25 | +DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
26 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame' | 26 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
27 | -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame' | 27 | +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' |
28 | 28 | ||
29 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' | 29 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' |
30 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' | 30 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' |
... | @@ -35,7 +35,10 @@ current_epoch = 0 | ... | @@ -35,7 +35,10 @@ current_epoch = 0 |
35 | def split_dataset(args, dataset, k): | 35 | def split_dataset(args, dataset, k): |
36 | # load dataset | 36 | # load dataset |
37 | X = list(range(len(dataset))) | 37 | X = list(range(len(dataset))) |
38 | - #Y = dataset.targets | 38 | + Y = dataset.targets |
39 | + #Y = [0]* len(X) | ||
40 | + | ||
41 | + #print("X:\n", type(X), np.shape(X), '\n', X, '\n') | ||
39 | 42 | ||
40 | # split to k-fold | 43 | # split to k-fold |
41 | # assert len(X) == len(Y) | 44 | # assert len(X) == len(Y) |
... | @@ -43,26 +46,49 @@ def split_dataset(args, dataset, k): | ... | @@ -43,26 +46,49 @@ def split_dataset(args, dataset, k): |
43 | def _it_to_list(_it): | 46 | def _it_to_list(_it): |
44 | return list(zip(*list(_it))) | 47 | return list(zip(*list(_it))) |
45 | 48 | ||
46 | - # sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | 49 | + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) |
47 | - # Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | 50 | + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) |
51 | + | ||
52 | + # print(type(Dm_indexes), np.shape(Dm_indexes)) | ||
53 | + # print("DM\n", len(Dm_indexes), Dm_indexes, "\nDA\n", len(Da_indexes),Da_indexes) | ||
54 | + | ||
55 | + | ||
56 | + return Dm_indexes, Da_indexes | ||
57 | + | ||
58 | +def split_dataset2222(args, dataset, k): | ||
59 | + # load dataset | ||
60 | + X = list(range(len(dataset))) | ||
61 | + | ||
62 | + # split to k-fold | ||
63 | + #assert len(X) == len(Y) | ||
48 | 64 | ||
49 | - x_train = [] | 65 | + def _it_to_list(_it): |
50 | - x_test = [] | 66 | + return list(zip(*list(_it))) |
51 | 67 | ||
68 | + x_train = () | ||
69 | + x_test = () | ||
52 | 70 | ||
53 | for i in range(k): | 71 | for i in range(k): |
54 | - xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) | 72 | + #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) |
55 | - x_train.append(xtr) | 73 | + xtr, xte = train_test_split(X, random_state=None, test_size=0.1) |
56 | - x_test.append(xte) | 74 | + x_train.append(np.array(xtr)) |
75 | + x_test.append(np.array(xte)) | ||
57 | 76 | ||
58 | - #kf = KFold(n_splits=k, random_state=args.seed, test) | 77 | + y_train = np.array([0]* len(x_train)) |
59 | - #kf.split(x_train) | 78 | + y_test = np.array([0]* len(x_test)) |
60 | 79 | ||
61 | - Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test) | 80 | + x_train = tuple(x_train) |
81 | + x_test = tuple(x_test) | ||
62 | 82 | ||
83 | + trainset = (zip(x_train, y_train),) | ||
84 | + testset = (zip(x_test, y_test),) | ||
63 | 85 | ||
64 | - return Dm_indexes, Da_indexes | 86 | + Dm_indexes, Da_indexes = trainset, testset |
65 | 87 | ||
88 | + print(type(Dm_indexes), np.shape(Dm_indexes)) | ||
89 | + print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes) | ||
90 | + | ||
91 | + return Dm_indexes, Da_indexes | ||
66 | 92 | ||
67 | def concat_image_features(image, features, max_features=3): | 93 | def concat_image_features(image, features, max_features=3): |
68 | _, h, w = image.shape | 94 | _, h, w = image.shape |
... | @@ -169,22 +195,24 @@ def select_scheduler(args, optimizer): | ... | @@ -169,22 +195,24 @@ def select_scheduler(args, optimizer): |
169 | 195 | ||
170 | 196 | ||
171 | class CustomDataset(Dataset): | 197 | class CustomDataset(Dataset): |
172 | - def __init__(self, path, target_path, transform = None): | 198 | + def __init__(self, path, transform = None): |
173 | self.path = path | 199 | self.path = path |
174 | self.transform = transform | 200 | self.transform = transform |
175 | #self.imgpath = glob.glob(path + '/*.png' | 201 | #self.imgpath = glob.glob(path + '/*.png' |
176 | - #self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3) | ||
177 | self.imgs = natsorted(os.listdir(path)) | 202 | self.imgs = natsorted(os.listdir(path)) |
178 | self.len = len(self.imgs) | 203 | self.len = len(self.imgs) |
179 | #self.len = self.img.shape[0] | 204 | #self.len = self.img.shape[0] |
180 | - self.targets = pd.read_csv(target_path, header = None) | 205 | + self.targets = [0]* self.len |
181 | 206 | ||
182 | def __len__(self): | 207 | def __len__(self): |
183 | return self.len | 208 | return self.len |
184 | 209 | ||
185 | def __getitem__(self, idx): | 210 | def __getitem__(self, idx): |
211 | + # print("\n\nIDX: ", idx, '\n', type(idx), '\n') | ||
212 | + # print("\n\nimgs[idx]: ", self.imgs[idx], '\n', type(self.imgs[idx]), '\n') | ||
186 | #img, targets = self.img[idx], self.targets[idx] | 213 | #img, targets = self.img[idx], self.targets[idx] |
187 | img_loc = os.path.join(self.path, self.imgs[idx]) | 214 | img_loc = os.path.join(self.path, self.imgs[idx]) |
215 | + targets = self.targets[idx] | ||
188 | #img = self.img[idx] | 216 | #img = self.img[idx] |
189 | image = Image.open(img_loc) | 217 | image = Image.open(img_loc) |
190 | 218 | ||
... | @@ -192,7 +220,7 @@ class CustomDataset(Dataset): | ... | @@ -192,7 +220,7 @@ class CustomDataset(Dataset): |
192 | #img = self.transform(img) | 220 | #img = self.transform(img) |
193 | tensor_image = self.transform(image) | 221 | tensor_image = self.transform(image) |
194 | #return img, targets | 222 | #return img, targets |
195 | - return tensor_image | 223 | + return tensor_image, targets |
196 | 224 | ||
197 | def get_dataset(args, transform, split='train'): | 225 | def get_dataset(args, transform, split='train'): |
198 | assert split in ['train', 'val', 'test', 'trainval'] | 226 | assert split in ['train', 'val', 'test', 'trainval'] |
... | @@ -224,9 +252,9 @@ def get_dataset(args, transform, split='train'): | ... | @@ -224,9 +252,9 @@ def get_dataset(args, transform, split='train'): |
224 | 252 | ||
225 | elif args.dataset == 'BraTS': | 253 | elif args.dataset == 'BraTS': |
226 | if split in ['train']: | 254 | if split in ['train']: |
227 | - dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform) | 255 | + dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) |
228 | else: | 256 | else: |
229 | - dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform) | 257 | + dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) |
230 | 258 | ||
231 | 259 | ||
232 | else: | 260 | else: |
... | @@ -250,6 +278,7 @@ def get_inf_dataloader(args, dataset): | ... | @@ -250,6 +278,7 @@ def get_inf_dataloader(args, dataset): |
250 | 278 | ||
251 | while True: | 279 | while True: |
252 | try: | 280 | try: |
281 | + #print("batch=dataloader:\n", batch, '\n') | ||
253 | batch = next(data_loader) | 282 | batch = next(data_loader) |
254 | 283 | ||
255 | except StopIteration: | 284 | except StopIteration: |
... | @@ -334,8 +363,7 @@ def get_valid_transform(args, model): | ... | @@ -334,8 +363,7 @@ def get_valid_transform(args, model): |
334 | 363 | ||
335 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | 364 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): |
336 | model.train() | 365 | model.train() |
337 | - print('\nBatch\n', batch) | 366 | + #print('\nBatch\n', batch) |
338 | - print('\nBatch size\n', batch.size()) | ||
339 | images, target = batch | 367 | images, target = batch |
340 | 368 | ||
341 | if device: | 369 | if device: | ... | ... |
-
Please register or login to post a comment