# dataset.py 1.5 KB
import functools
import logging

from box import Box
from torchvision.transforms import transforms

from src.common.data.ukbb_brain_age import UKBBBrainAGE
from src.common.data_utils import frame_drop

# Logger used by this module. getLogger() is called without a name, which
# returns the root logger, so records emitted here are attributed to the root
# logger rather than to this module.
logger = logging.getLogger()


def get_dataset(name, test_csv=None, train_csv=None, valid_csv=None, root_path=None,
                train_num_sample=-1, frame_keep_style="random", frame_keep_fraction=1.0,
                frame_dim=1, impute=False, **kwargs):
    """Build the train/test/valid datasets registered under ``name``.

    Only ``"brain_age"`` is recognised. All three splits share a single
    transform pipeline that applies ``frame_drop`` with the frame options
    given here.

    Args:
        name: Dataset identifier; must be ``"brain_age"``.
        test_csv: Metadata file passed to the test split.
        train_csv: Metadata file passed to the train split.
        valid_csv: Metadata file for the validation split. When falsy,
            ``test_csv`` is used for validation instead.
        root_path: Forwarded as ``root`` to every ``UKBBBrainAGE`` split.
        train_num_sample: Forwarded as ``num_sample`` to the train split only.
        frame_keep_style: Forwarded to ``frame_drop``.
        frame_keep_fraction: Forwarded to ``frame_drop``.
        frame_dim: Forwarded to ``frame_drop``.
        impute: Forwarded to ``frame_drop``.
        **kwargs: Accepted and ignored.

    Returns:
        A 2-tuple ``(splits, extra)``: ``splits`` is a ``Box`` with the keys
        ``"train"``, ``"test"`` and ``"valid"``; ``extra`` is an empty dict.

    Raises:
        ValueError: If ``name`` is not a recognised dataset name. The same
            message is logged at error level before raising.
    """
    # Reject unknown names up front so the happy path below stays flat.
    if name != "brain_age":
        message = f"Invalid data name {name} specified"
        logger.error(message)
        raise ValueError(message)

    # Transformation that removes frames. A functools.partial is used rather
    # than a lambda because lambdas cannot be pickled, which matters when the
    # dataset is handed to worker processes; this assumes frame_drop is an
    # importable module-level callable.
    frame_drop_transform = functools.partial(
        frame_drop,
        frame_keep_style=frame_keep_style,
        frame_keep_fraction=frame_keep_fraction,
        frame_dim=frame_dim,
        impute=impute,
    )
    transform = transforms.Compose([frame_drop_transform])

    train_data = UKBBBrainAGE(root=root_path, metadatafile=train_csv,
                              num_sample=train_num_sample, transform=transform)
    test_data = UKBBBrainAGE(root=root_path, metadatafile=test_csv, transform=transform)
    # Validate on the test metadata when no validation file was supplied.
    valid_data = UKBBBrainAGE(root=root_path, metadatafile=valid_csv or test_csv,
                              transform=transform)

    return Box({"train": train_data, "test": test_data, "valid": valid_data}), {}