Showing
1 changed file
with
60 additions
and
0 deletions
src/common/data/ukbb_brain_age.py
0 → 100644
1 | +import logging | ||
2 | +import os | ||
3 | + | ||
4 | +import nibabel | ||
5 | +import numpy | ||
6 | +import pandas | ||
7 | +from torchvision.datasets import VisionDataset | ||
8 | + | ||
9 | +logger = logging.getLogger() | ||
10 | +FILEPATHKEY = "9dof_2mm_vol" | ||
11 | + | ||
12 | + | ||
13 | +class UKBBBrainAGE(VisionDataset): | ||
14 | + @staticmethod | ||
15 | + def get_path(root, path): | ||
16 | + if path == "/" or root is None: | ||
17 | + return path | ||
18 | + return os.path.join(root, str(path)) | ||
19 | + | ||
20 | + def __init__(self, root, metadatafile, transform=None, target_transform=None, verify=False, | ||
21 | + num_sample=-1, random_state=0): | ||
22 | + super().__init__(root, transform=transform, target_transform=target_transform) | ||
23 | + self.df = pandas.read_csv(metadatafile) | ||
24 | + | ||
25 | + # do a random sample of dataset | ||
26 | + if num_sample > 0: | ||
27 | + # fixed seed will be useful to train multiple models with same data | ||
28 | + self.df = self.df.sample(n=num_sample, random_state=random_state, replace=True) | ||
29 | + | ||
30 | + if verify: | ||
31 | + # remove all those entries for which we dont have file | ||
32 | + indices = [] | ||
33 | + for i, row in self.df.iterrows(): | ||
34 | + if not os.path.exists(self.get_path(root, row[FILEPATHKEY])): | ||
35 | + indices.append(i) | ||
36 | + if indices: | ||
37 | + logger.info(f"Dropping {len(indices)}") | ||
38 | + logger.debug(f"Dropped rows {indices}") | ||
39 | + self.df = self.df.drop(index=indices) | ||
40 | + | ||
41 | + def __getitem__(self, index): | ||
42 | + row = self.df.iloc[index] | ||
43 | + path = self.get_path(self.root, row[FILEPATHKEY]) | ||
44 | + subject_id = row["subject_id"] | ||
45 | + age = row["age_at_scan"] | ||
46 | + img = nibabel.load(path).get_fdata() | ||
47 | + img = (img - img.mean()) / img.std() | ||
48 | + scan = img[numpy.newaxis, :, :, :] | ||
49 | + age = age | ||
50 | + | ||
51 | + if self.transform: | ||
52 | + scan = self.transform(scan) | ||
53 | + | ||
54 | + if self.target_transform: | ||
55 | + age = self.target_transform(age) | ||
56 | + | ||
57 | + return numpy.float32(scan), numpy.float32(age), subject_id | ||
58 | + | ||
59 | + def __len__(self): | ||
60 | + return self.df.shape[0] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment