Hyunji

dataset

1 +import logging
2 +
3 +from box import Box
4 +from torchvision.transforms import transforms
5 +
6 +from src.common.data.ukbb_brain_age import UKBBBrainAGE
7 +from src.common.data_utils import frame_drop
8 +
9 +logger = logging.getLogger()
10 +
11 +
12 +def get_dataset(name, test_csv=None, train_csv=None, valid_csv=None, root_path=None,
13 + train_num_sample=-1, frame_keep_style="random", frame_keep_fraction=1.0,
14 + frame_dim=1, impute=False, **kwargs):
15 + """ return dataset """
16 +
17 + if name == "brain_age":
18 + # Transformations to remove frames
19 + frame_drop_transform = lambda x: frame_drop(x, frame_keep_style=frame_keep_style,
20 + frame_keep_fraction=frame_keep_fraction,
21 + frame_dim=frame_dim, impute=impute)
22 + # Transformation to add noise to frames
23 + transform = transforms.Compose([frame_drop_transform])
24 + train_data = UKBBBrainAGE(root=root_path, metadatafile=train_csv,
25 + num_sample=train_num_sample, transform=transform)
26 + test_data = UKBBBrainAGE(root=root_path, metadatafile=test_csv, transform=transform)
27 + valid_data = UKBBBrainAGE(root=root_path, metadatafile=valid_csv if valid_csv else test_csv,
28 + transform=transform)
29 + return Box({"train": train_data, "test": test_data, "valid": valid_data}), {}
30 +
31 + logger.error(f"Invalid data name {name} specified")
32 + raise Exception(f"Invalid data name {name} specified")