# dataFromH5.py (496 bytes)
import torch.utils.data as data
import torch
import h5py

class Read_dataset_h5(data.Dataset):
    """Map-style dataset serving (input, label) pairs from an HDF5 file.

    The file must contain two datasets named ``'input'`` and ``'label'``
    whose first axis indexes samples. Each item is returned as a pair of
    ``float32`` tensors holding the index-th slice of each dataset.

    NOTE(review): the h5py file handle is opened in ``__init__`` and shared
    by every process that inherits this object; with a multi-worker
    ``DataLoader`` h5py handles are generally not fork-safe — confirm how
    this dataset is consumed.
    """

    def __init__(self, file_path):
        """Open ``file_path`` read-only and bind its 'input'/'label' datasets.

        Raises:
            KeyError: if either 'input' or 'label' is missing from the file.
            ValueError: if the two datasets have different sample counts.
        """
        super().__init__()
        # Explicit read-only mode: older h5py versions defaulted to append
        # mode ('a'), which can create or modify the file on disk.
        hf = h5py.File(file_path, "r")
        try:
            # Subscripting (instead of .get) raises KeyError naming the
            # missing dataset rather than storing None and failing later.
            self.input = hf["input"]
            self.label = hf["label"]
            if self.input.shape[0] != self.label.shape[0]:
                raise ValueError(
                    f"'input' has {self.input.shape[0]} samples but "
                    f"'label' has {self.label.shape[0]}"
                )
        except Exception:
            # Do not leak the open file when construction fails.
            hf.close()
            raise
        # Keep the handle so it can be released explicitly via close().
        self._hf = hf

    def __getitem__(self, index):
        """Return ``(input[index], label[index])`` as float32 tensors."""
        # Indexing only the first axis works for datasets of any rank and is
        # equivalent to the explicit [index, :, :, :] for rank-4 data;
        # as_tensor also accepts scalar per-sample items.
        sample = torch.as_tensor(self.input[index]).float()
        target = torch.as_tensor(self.label[index]).float()
        return sample, target

    def __len__(self):
        """Return the number of samples (size of the first axis of 'input')."""
        return self.input.shape[0]

    def close(self):
        """Close the underlying HDF5 file if this instance opened one."""
        hf = getattr(self, "_hf", None)
        if hf is not None:
            hf.close()
            self._hf = None