dataloader.py
2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Nov 8 02:47:50 2020
@author: dhk1349
Input data: each sample is 300 x 50 in size (300 frames).
This loader extracts the samples of one particular action class,
subsamples each of them to 64 frames,
and saves the result to a NumPy file (.npy).
"""
#MAKING DATASET2
import os
import sys

import numpy as np
def load_dataset(
    path="./ntu/xsub/",
    *,
    target_label=15,
    total_frames=300,
    stride=4,
    n_drop=11,
    save_path=(
        "/home/dhk1349/Desktop/Github/Deep-Learning/Pytorch/Action Generation/"
        "Generating other action classes/run on the spot/data_run on the spot/"
    ),
    rng=None,
):
    """Extract one action class from the x-sub split, subsample its frames, save it.

    Train and validation samples whose label equals ``target_label`` are
    pooled into one array. Every sample keeps the same randomly chosen set
    of frames: the candidates are ``range(0, total_frames, stride)`` (75 with
    the defaults), minus ``n_drop`` *distinct* candidates chosen at random
    (11 by default, leaving 64 frames). Kept frames stay in temporal order.

    The data arrays are assumed to be 4-D with the frame axis at position 2
    (presumably ``(N, C, T, V)`` -- confirm against the data producer). The
    returned array uses the same axis layout with the frame axis subsampled.

    Args:
        path: Directory containing ``xsub_train_300_50.npy``,
            ``xsub_val_300_50.npy``, ``train_label.pkl.npy`` and
            ``val_label.pkl.npy``.
        target_label: Label value whose samples are extracted (15 by
            default; the original script also tried 23, "kicking sth").
        total_frames: Upper bound (exclusive) of the candidate frame indices.
        stride: Step between candidate frame indices.
        n_drop: Number of distinct candidate frames to discard at random.
        save_path: Directory in which ``Integrated_dataset_64_50.npy`` is
            written, or ``None`` to skip saving. The file name is fixed and
            does not reflect non-default frame counts.
        rng: Object with a NumPy-style ``choice(a, size, replace=...)``
            method, e.g. ``np.random.default_rng(seed)`` for reproducible
            output. Defaults to the global ``np.random`` state.

    Returns:
        Array of shape ``(M, A1, F, A3)``. ``M`` is the number of matching
        train + validation samples (train first). ``F`` is
        ``len(range(0, total_frames, stride)) - n_drop``. ``A1`` and ``A3``
        are the input's axes 1 and 3.

    Raises:
        ValueError: If ``n_drop`` exceeds the number of candidate frames.
        IndexError: If ``total_frames`` exceeds the data's frame count.
    """
    train_data = np.load(os.path.join(path, "xsub_train_300_50.npy"))
    test_data = np.load(os.path.join(path, "xsub_val_300_50.npy"))
    # Element [1] of each pickled label file holds the per-sample labels
    # (element [0] is presumably the sample names -- confirm).
    # NOTE: allow_pickle=True can execute arbitrary code; load trusted files only.
    train_label = np.load(
        os.path.join(path, "train_label.pkl.npy"), allow_pickle=True
    )[1]
    test_label = np.load(
        os.path.join(path, "val_label.pkl.npy"), allow_pickle=True
    )[1]

    # Move the frame axis next to the sample axis so both can be indexed together.
    train_data = train_data.transpose(0, 2, 1, 3)
    test_data = test_data.transpose(0, 2, 1, 3)
    print("input data size: ", train_data.shape, test_data.shape)

    if rng is None:
        rng = np.random
    candidates = np.arange(0, total_frames, stride)
    # replace=False is essential. With replacement (NumPy's default), duplicate
    # picks discarded fewer than n_drop frames, so the output was longer than 64.
    dropped = rng.choice(len(candidates), n_drop, replace=False)
    frames = np.delete(candidates, dropped)  # preserves temporal order
    print("extract idx len: ", len(frames))

    # Select matching samples and kept frames in one fancy-indexing step per split.
    # np.ix_ also handles an empty match, giving a (0, F, ...) array instead of crashing.
    selected = [
        data[np.ix_(np.flatnonzero(np.asarray(labels) == target_label), frames)]
        for data, labels in ((train_data, train_label), (test_data, test_label))
    ]
    dataset = np.concatenate(selected)
    print(dataset.shape)
    # Restore the input's axis layout (frame axis back at position 2).
    dataset = dataset.transpose(0, 2, 1, 3)
    print(dataset.shape)

    if save_path is not None:
        np.save(os.path.join(save_path, "Integrated_dataset_64_50"), dataset)
    return dataset
if __name__ == "__main__":
    # Usage: python dataloader.py [NTU_XSUB_DIR]
    # Without an argument, the author's original dataset location is used.
    load_dataset(
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/dhk1349/Desktop/Capstone Design2/ntu/xsub/"
    )