Showing
1 changed file
with
32 additions
and
0 deletions
2DCNN/src/common/dataset.py
0 → 100644
1 | +import logging | ||
2 | + | ||
3 | +from box import Box | ||
4 | +from torchvision.transforms import transforms | ||
5 | + | ||
6 | +from src.common.data.ukbb_brain_age import UKBBBrainAGE | ||
7 | +from src.common.data_utils import frame_drop | ||
8 | + | ||
9 | +logger = logging.getLogger() | ||
10 | + | ||
11 | + | ||
12 | +def get_dataset(name, test_csv=None, train_csv=None, valid_csv=None, root_path=None, | ||
13 | + train_num_sample=-1, frame_keep_style="random", frame_keep_fraction=1.0, | ||
14 | + frame_dim=1, impute=False, **kwargs): | ||
15 | + """ return dataset """ | ||
16 | + | ||
17 | + if name == "brain_age": | ||
18 | + # Transformations to remove frames | ||
19 | + frame_drop_transform = lambda x: frame_drop(x, frame_keep_style=frame_keep_style, | ||
20 | + frame_keep_fraction=frame_keep_fraction, | ||
21 | + frame_dim=frame_dim, impute=impute) | ||
22 | + # Transformation to add noise to frames | ||
23 | + transform = transforms.Compose([frame_drop_transform]) | ||
24 | + train_data = UKBBBrainAGE(root=root_path, metadatafile=train_csv, | ||
25 | + num_sample=train_num_sample, transform=transform) | ||
26 | + test_data = UKBBBrainAGE(root=root_path, metadatafile=test_csv, transform=transform) | ||
27 | + valid_data = UKBBBrainAGE(root=root_path, metadatafile=valid_csv if valid_csv else test_csv, | ||
28 | + transform=transform) | ||
29 | + return Box({"train": train_data, "test": test_data, "valid": valid_data}), {} | ||
30 | + | ||
31 | + logger.error(f"Invalid data name {name} specified") | ||
32 | + raise Exception(f"Invalid data name {name} specified") |
-
Please register or login to post a comment