Hyunji

ukbb brain age

1 +import logging
2 +import os
3 +
4 +import nibabel
5 +import numpy
6 +import pandas
7 +from torchvision.datasets import VisionDataset
8 +
9 +logger = logging.getLogger()
10 +FILEPATHKEY = "9dof_2mm_vol"
11 +
12 +
13 +class UKBBBrainAGE(VisionDataset):
14 + @staticmethod
15 + def get_path(root, path):
16 + if path == "/" or root is None:
17 + return path
18 + return os.path.join(root, str(path))
19 +
20 + def __init__(self, root, metadatafile, transform=None, target_transform=None, verify=False,
21 + num_sample=-1, random_state=0):
22 + super().__init__(root, transform=transform, target_transform=target_transform)
23 + self.df = pandas.read_csv(metadatafile)
24 +
25 + # do a random sample of dataset
26 + if num_sample > 0:
27 + # fixed seed will be useful to train multiple models with same data
28 + self.df = self.df.sample(n=num_sample, random_state=random_state, replace=True)
29 +
30 + if verify:
31 + # remove all those entries for which we dont have file
32 + indices = []
33 + for i, row in self.df.iterrows():
34 + if not os.path.exists(self.get_path(root, row[FILEPATHKEY])):
35 + indices.append(i)
36 + if indices:
37 + logger.info(f"Dropping {len(indices)}")
38 + logger.debug(f"Dropped rows {indices}")
39 + self.df = self.df.drop(index=indices)
40 +
41 + def __getitem__(self, index):
42 + row = self.df.iloc[index]
43 + path = self.get_path(self.root, row[FILEPATHKEY])
44 + subject_id = row["subject_id"]
45 + age = row["age_at_scan"]
46 + img = nibabel.load(path).get_fdata()
47 + img = (img - img.mean()) / img.std()
48 + scan = img[numpy.newaxis, :, :, :]
49 + age = age
50 +
51 + if self.transform:
52 + scan = self.transform(scan)
53 +
54 + if self.target_transform:
55 + age = self.target_transform(age)
56 +
57 + return numpy.float32(scan), numpy.float32(age), subject_id
58 +
59 + def __len__(self):
60 + return self.df.shape[0]
...\ No newline at end of file ...\ No newline at end of file