Showing
6 changed files
with
47 additions
and
8 deletions
| ... | @@ -5,6 +5,7 @@ import collections | ... | @@ -5,6 +5,7 @@ import collections |
| 5 | import pickle as cp | 5 | import pickle as cp |
| 6 | import glob | 6 | import glob |
| 7 | import numpy as np | 7 | import numpy as np |
| 8 | +import pandas as pd | ||
| 8 | 9 | ||
| 9 | import torch | 10 | import torch |
| 10 | import torchvision | 11 | import torchvision |
| ... | @@ -31,7 +32,7 @@ current_epoch = 0 | ... | @@ -31,7 +32,7 @@ current_epoch = 0 |
| 31 | def split_dataset(args, dataset, k): | 32 | def split_dataset(args, dataset, k): |
| 32 | # load dataset | 33 | # load dataset |
| 33 | X = list(range(len(dataset))) | 34 | X = list(range(len(dataset))) |
| 34 | - Y = dataset.targets | 35 | + Y = dataset |
| 35 | 36 | ||
| 36 | # split to k-fold | 37 | # split to k-fold |
| 37 | assert len(X) == len(Y) | 38 | assert len(X) == len(Y) |
| ... | @@ -162,9 +163,11 @@ class CustomDataset(Dataset): | ... | @@ -162,9 +163,11 @@ class CustomDataset(Dataset): |
| 162 | return self.len | 163 | return self.len |
| 163 | 164 | ||
| 164 | def __getitem__(self, idx): | 165 | def __getitem__(self, idx): |
| 165 | - if self.transforms is not None: | 166 | + img, targets = self.img[idx], self.targets[idx] |
| 166 | - img = self.transforms(img) | 167 | + |
| 167 | - return img | 168 | + if self.transform is not None: |
| 169 | + img = self.transform(img) | ||
| 170 | + return img, targets | ||
| 168 | 171 | ||
| 169 | def get_dataset(args, transform, split='train'): | 172 | def get_dataset(args, transform, split='train'): |
| 170 | assert split in ['train', 'val', 'test', 'trainval'] | 173 | assert split in ['train', 'val', 'test', 'trainval'] | ... | ... |
| ... | @@ -15,6 +15,9 @@ from torchvision.transforms import transforms | ... | @@ -15,6 +15,9 @@ from torchvision.transforms import transforms |
| 15 | from sklearn.model_selection import StratifiedShuffleSplit | 15 | from sklearn.model_selection import StratifiedShuffleSplit |
| 16 | from theconf import Config as C | 16 | from theconf import Config as C |
| 17 | 17 | ||
| 18 | + | ||
| 19 | +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | ||
| 20 | + | ||
| 18 | from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet | 21 | from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet |
| 19 | from FastAutoAugment.augmentations import * | 22 | from FastAutoAugment.augmentations import * |
| 20 | from FastAutoAugment.common import get_logger | 23 | from FastAutoAugment.common import get_logger |
| ... | @@ -79,6 +82,29 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode | ... | @@ -79,6 +82,29 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode |
| 79 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | 82 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| 80 | ]) | 83 | ]) |
| 81 | 84 | ||
| 85 | + elif 'BraTS' in dataset: | ||
| 86 | + input_size = 240 | ||
| 87 | + sized_size = 256 | ||
| 88 | + | ||
| 89 | + if 'efficientnet' in C.get()['model']['type']: | ||
| 90 | + input_size = EfficientNet.get_image_size(C.get()['model']['type']) | ||
| 91 | + sized_size = input_size + 16 # TODO | ||
| 92 | + | ||
| 93 | + logger.info('size changed to %d/%d.' % (input_size, sized_size)) | ||
| 94 | + | ||
| 95 | + transform_train = transforms.Compose([ | ||
| 96 | + EfficientNetRandomCrop(input_size), | ||
| 97 | + transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC), | ||
| 98 | + transforms.RandomHorizontalFlip(), | ||
| 99 | + transforms.ToTensor(), | ||
| 100 | + ]) | ||
| 101 | + | ||
| 102 | + transform_test = transforms.Compose([ | ||
| 103 | + EfficientNetCenterCrop(input_size), | ||
| 104 | + transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC), | ||
| 105 | + transforms.ToTensor(), | ||
| 106 | + ]) | ||
| 107 | + | ||
| 82 | else: | 108 | else: |
| 83 | raise ValueError('dataset=%s' % dataset) | 109 | raise ValueError('dataset=%s' % dataset) |
| 84 | 110 | ||
| ... | @@ -111,7 +137,10 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode | ... | @@ -111,7 +137,10 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode |
| 111 | if C.get()['cutout'] > 0: | 137 | if C.get()['cutout'] > 0: |
| 112 | transform_train.transforms.append(CutoutDefault(C.get()['cutout'])) | 138 | transform_train.transforms.append(CutoutDefault(C.get()['cutout'])) |
| 113 | 139 | ||
| 114 | - if dataset == 'cifar10': | 140 | + if dataset == 'BraTS': |
| 141 | + total_trainset = | ||
| 142 | + testset = | ||
| 143 | + elif dataset == 'cifar10': | ||
| 115 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) | 144 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) |
| 116 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) | 145 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) |
| 117 | elif dataset == 'reduced_cifar10': | 146 | elif dataset == 'reduced_cifar10': | ... | ... |
| ... | @@ -16,6 +16,9 @@ from ray.tune.suggest import HyperOptSearch | ... | @@ -16,6 +16,9 @@ from ray.tune.suggest import HyperOptSearch |
| 16 | from ray.tune import register_trainable, run_experiments | 16 | from ray.tune import register_trainable, run_experiments |
| 17 | from tqdm import tqdm | 17 | from tqdm import tqdm |
| 18 | 18 | ||
| 19 | + | ||
| 20 | +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | ||
| 21 | + | ||
| 19 | from FastAutoAugment.archive import remove_deplicates, policy_decoder | 22 | from FastAutoAugment.archive import remove_deplicates, policy_decoder |
| 20 | from FastAutoAugment.augmentations import augment_list | 23 | from FastAutoAugment.augmentations import augment_list |
| 21 | from FastAutoAugment.common import get_logger, add_filehandler | 24 | from FastAutoAugment.common import get_logger, add_filehandler | ... | ... |
| ... | @@ -19,6 +19,9 @@ import torch.distributed as dist | ... | @@ -19,6 +19,9 @@ import torch.distributed as dist |
| 19 | from tqdm import tqdm | 19 | from tqdm import tqdm |
| 20 | from theconf import Config as C, ConfigArgumentParser | 20 | from theconf import Config as C, ConfigArgumentParser |
| 21 | 21 | ||
| 22 | + | ||
| 23 | +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | ||
| 24 | + | ||
| 22 | from FastAutoAugment.common import get_logger, EMA, add_filehandler | 25 | from FastAutoAugment.common import get_logger, EMA, add_filehandler |
| 23 | from FastAutoAugment.data import get_dataloaders | 26 | from FastAutoAugment.data import get_dataloaders |
| 24 | from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet | 27 | from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet | ... | ... |
| ... | @@ -32,6 +32,7 @@ for i = 1 : length(subFolders) | ... | @@ -32,6 +32,7 @@ for i = 1 : length(subFolders) |
| 32 | 32 | ||
| 33 | % copy flair, segment flair data | 33 | % copy flair, segment flair data |
| 34 | 34 | ||
| 35 | + % seg의 검은 부분(정보 x)과 같은 인덱스 = 0 | ||
| 35 | cp_flair(seg == 0) = 0; | 36 | cp_flair(seg == 0) = 0; |
| 36 | 37 | ||
| 37 | % save a segmented data | 38 | % save a segmented data | ... | ... |
| 1 | inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\'; | 1 | inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\'; |
| 2 | -outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\frame\'; | 2 | +outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\total_frame\'; |
| 3 | 3 | ||
| 4 | files = dir(inputheader); | 4 | files = dir(inputheader); |
| 5 | id = {files.name}; | 5 | id = {files.name}; |
| ... | @@ -38,14 +38,14 @@ for i = 1 : length(files) | ... | @@ -38,14 +38,14 @@ for i = 1 : length(files) |
| 38 | c = 0; | 38 | c = 0; |
| 39 | step = round(((en) - (st))/11); | 39 | step = round(((en) - (st))/11); |
| 40 | for k = st + step : step : st + step*10 | 40 | for k = st + step : step : st + step*10 |
| 41 | - c = c+ 1; | ||
| 42 | - | ||
| 43 | type = '.png'; | 41 | type = '.png'; |
| 44 | filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_seg_flair_c.png | 42 | filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_seg_flair_c.png |
| 45 | outpath = strcat(outfolder, filename); | 43 | outpath = strcat(outfolder, filename); |
| 46 | % typecase int16 to double, range[0, 1], rotate 90 and filp updown | 44 | % typecase int16 to double, range[0, 1], rotate 90 and filp updown |
| 47 | cp_data = flipud(rot90(mat2gray(double(data(:,:,k))))); | 45 | cp_data = flipud(rot90(mat2gray(double(data(:,:,k))))); |
| 48 | imwrite(cp_data, outpath); | 46 | imwrite(cp_data, outpath); |
| 47 | + | ||
| 48 | + c = c+ 1; | ||
| 49 | end | 49 | end |
| 50 | 50 | ||
| 51 | end | 51 | end | ... | ... |
-
Please register or login to post a comment