조현아

rm stratified FAA getBraTS_3

@@ -3,4 +3,5 @@ tb-nightly
torchvision
torch
hyperopt
fire
pillow==6.2.1
natsort
\ No newline at end of file
......
@@ -6,7 +6,8 @@ import pickle as cp
import glob
import numpy as np
import pandas as pd
from natsort import natsorted
from PIL import Image
import torch
import torchvision
import torch.nn.functional as F
@@ -16,12 +17,14 @@ from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from networks import basenet
TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame'
VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame'
TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
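These '/content/drive/...' constants imply the script runs in Google Colab, where Drive must be mounted before these paths resolve. A minimal sketch using the standard Colab helper (not part of this commit):

# Mount Google Drive so the BraTS paths above are reachable (Colab only).
from google.colab import drive
drive.mount('/content/drive')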
@@ -32,16 +35,31 @@ current_epoch = 0
def split_dataset(args, dataset, k):
    # load dataset
    X = list(range(len(dataset)))
    Y = dataset
    #Y = dataset.targets

    # split to k-fold
    # assert len(X) == len(Y)

    def _it_to_list(_it):
        return list(zip(*list(_it)))

    # sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
    # Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))

    x_train = []
    x_test = []
    for i in range(k):
        # Vary the seed per fold: a fixed random_state returns the identical
        # split on every iteration, so all k folds would be the same.
        xtr, xte = train_test_split(X, random_state=args.seed + i, test_size=0.1)
        x_train.append(xtr)
        x_test.append(xte)
    #kf = KFold(n_splits=k, random_state=args.seed, test)
    #kf.split(x_train)

    Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test)
    return Dm_indexes, Da_indexes
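For context, a sketch of how the returned index arrays would typically be consumed, using the Subset and DataLoader imports above; the fold count, batch size, and loader names are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch, not part of the diff.
Dm_indexes, Da_indexes = split_dataset(args, dataset, k=5)
for fold, (tr_idx, te_idx) in enumerate(zip(Dm_indexes, Da_indexes)):
    # Subset views the same dataset through each fold's index list.
    train_loader = DataLoader(Subset(dataset, tr_idx.tolist()), batch_size=32, shuffle=True)
    test_loader = DataLoader(Subset(dataset, te_idx.tolist()), batch_size=32)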
@@ -154,20 +172,27 @@ class CustomDataset(Dataset):
    def __init__(self, path, target_path, transform=None):
        self.path = path
        self.transform = transform
        #self.img = np.load(path)
        self.img = glob.glob(path + '/*.png')
        self.len = len(self.img)
        #self.imgpath = glob.glob(path + '/*.png'
        #self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3)
        # natsorted gives numeric filename order (frame_2 before frame_10),
        # which plain lexical sorting would not.
        self.imgs = natsorted(os.listdir(path))
        self.len = len(self.imgs)
        #self.len = self.img.shape[0]
        self.targets = pd.read_csv(target_path, header=None)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        #img, targets = self.img[idx], self.targets[idx]
        img_loc = os.path.join(self.path, self.imgs[idx])
        #img = self.img[idx]
        image = Image.open(img_loc)
        if self.transform is not None:
            #img = self.transform(img)
            tensor_image = self.transform(image)
        else:
            # Fall back to a plain tensor so tensor_image is always defined.
            tensor_image = torchvision.transforms.functional.to_tensor(image)
        #return img, targets
        # Return the label as well: train_step unpacks `images, target = batch`.
        # .iloc selects the row; self.targets[idx] would select a column.
        return tensor_image, self.targets.iloc[idx, 0]
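A hedged usage sketch for CustomDataset, assuming the CSV holds one label per image row; the transform and batch size are illustrative assumptions:

# Illustrative only: build a loader over the BraTS training frames.
import torchvision.transforms as T
transform = T.Compose([T.Grayscale(), T.ToTensor()])
train_set = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
images, targets = next(iter(train_loader))  # images: [B, 1, H, W]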
def get_dataset(args, transform, split='train'):
    assert split in ['train', 'val', 'test', 'trainval']
@@ -309,6 +334,8 @@ def get_valid_transform(args, model):
def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    model.train()
    # batch is an (images, target) pair and has no .size() of its own;
    # unpack first and print the image tensor's shape instead.
    images, target = batch
    print('\nBatch\n', batch)
    print('\nBatch size\n', images.size())
    if device:
......