dataset.py
4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import torch
import torch.nn as nn
import torch.utils.data as data
import cv2
import numpy as np
from tqdm import tqdm
class Places365Dataset(data.Dataset):
    """Places365 images paired with randomly generated free-form masks.

    Expects ``cfg.root_dir`` to contain one sub-directory per class, each
    holding image files. Every image is resized to
    ``cfg.img_size x cfg.img_size`` and a fresh free-form mask is drawn on
    each access using ``cfg.max_angle``, ``cfg.max_len``, ``cfg.max_width``
    and ``cfg.times``.
    """

    def __init__(self, cfg):
        self.root_dir = cfg.root_dir
        self.cfg = cfg
        self.load_images()

    def load_images(self):
        """Collect every file path under ``root_dir/<class>/`` into ``self.fns``.

        Listings are sorted so the index -> file mapping is reproducible
        (``os.listdir`` order is arbitrary). Non-directory entries directly
        under ``root_dir`` are skipped instead of crashing the listing.
        """
        self.fns = []
        for cls_id in sorted(os.listdir(self.root_dir)):
            cls_dir = os.path.join(self.root_dir, cls_id)
            if not os.path.isdir(cls_dir):
                continue
            self.fns.extend(
                os.path.join(cls_dir, img_name)
                for img_name in sorted(os.listdir(cls_dir))
            )

    def __getitem__(self, index):
        """Return ``(img, mask)`` for sample ``index``.

        Returns:
            img: float32 tensor of shape (3, img_size, img_size), RGB scaled
                to [0, 1].
            mask: float32 tensor of shape (1, img_size, img_size), 1.0 on
                drawn strokes and 0.0 elsewhere.

        Raises:
            OSError: if OpenCV cannot read/decode the image file.
        """
        img_path = self.fns[index]
        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
        mask = self.random_ff_mask(
            shape=self.cfg.img_size,
            max_angle=self.cfg.max_angle,
            max_len=self.cfg.max_len,
            max_width=self.cfg.max_width,
            times=self.cfg.times,
        )
        img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
        mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
        return img, mask

    def collate_fn(self, batch):
        """Stack a list of ``(img, mask)`` pairs into a dict of batched tensors."""
        imgs = torch.stack([sample[0] for sample in batch])
        masks = torch.stack([sample[1] for sample in batch])
        return {
            'imgs': imgs,
            'masks': masks,
        }

    def __len__(self):
        """Return the number of image files found by ``load_images``."""
        return len(self.fns)

    def random_ff_mask(self, shape=256, max_angle=4, max_len=50, max_width=20, times=15):
        """Generate a random free-form (brush-stroke) mask.

        Draws ``randint(10, times)`` strokes (i.e. 10 to ``times - 1``). Each
        stroke starts at a random pixel and is made of 1-5 connected line
        segments with angle ``0.01 + randint(max_angle)`` radians (replaced
        by ``2*pi - angle`` for even-indexed strokes), length
        ``10 + randint(max_len)`` and brush width ``5 + randint(max_width)``.

        Uses the global ``np.random`` state. NOTE(review): with a DataLoader
        using ``num_workers > 0``, workers may share the same numpy seed
        unless a ``worker_init_fn`` reseeds them — confirm for the PyTorch
        version in use.

        Args:
            shape: side length of the square mask in pixels.
            max_angle: exclusive upper bound of the integer part of the angle.
            max_len: exclusive upper bound of the random extra segment length.
            max_width: exclusive upper bound of the random extra brush width.
            times: exclusive upper bound on the stroke count; must be > 10.

        Returns:
            np.ndarray: float32 array of shape (1, shape, shape) with 1.0 on
            stroke pixels and 0.0 elsewhere.

        Raises:
            ValueError: if ``times <= 10``.
        """
        if times <= 10:
            raise ValueError(
                f"times must be > 10 (got {times}): at least 10 strokes are always drawn"
            )
        height = width = shape
        mask = np.zeros((height, width), np.float32)
        num_strokes = np.random.randint(10, times)
        for i in range(num_strokes):
            start_x = np.random.randint(width)
            start_y = np.random.randint(height)
            for _ in range(1 + np.random.randint(5)):
                angle = 0.01 + np.random.randint(max_angle)
                if i % 2 == 0:
                    angle = 2 * 3.1415926 - angle
                length = 10 + np.random.randint(max_len)
                brush_w = 5 + np.random.randint(max_width)
                # int() truncates toward zero, same as the previous astype(np.int32),
                # and hands cv2 plain Python ints.
                end_x = int(start_x + length * np.sin(angle))
                end_y = int(start_y + length * np.cos(angle))
                # cv2.line takes (col, row) points; passing (y, x) transposes the
                # stroke, which is harmless because the mask is square.
                cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
                start_x, start_y = end_x, end_y
        return mask[np.newaxis]
class FacemaskDataset(data.Dataset):
    """CelebA face images paired with pre-computed binary masks.

    Expects ``cfg.root_dir`` to contain ``celeba512_30k/`` (images) and
    ``celeba512_30k_binary/`` (masks). The mask for ``<stem>.<ext>`` is
    ``<stem>_binary.jpg``; images without a matching mask file are skipped.
    """

    def __init__(self, cfg):
        self.root_dir = cfg.root_dir
        self.cfg = cfg
        self.mask_folder = os.path.join(self.root_dir, 'celeba512_30k_binary')
        self.img_folder = os.path.join(self.root_dir, 'celeba512_30k')
        self.load_images()

    def load_images(self):
        """Populate ``self.fns`` with ``[img_path, mask_path]`` pairs sorted by image name."""
        self.fns = []
        for img_name in sorted(os.listdir(self.img_folder)):
            # splitext only strips the final extension, so dotted stems
            # such as "a.b.jpg" map to "a.b_binary.jpg".
            stem = os.path.splitext(img_name)[0]
            mask_path = os.path.join(self.mask_folder, stem + '_binary.jpg')
            if os.path.isfile(mask_path):
                self.fns.append([os.path.join(self.img_folder, img_name), mask_path])

    def __getitem__(self, index):
        """Return ``(img, mask)`` for sample ``index``.

        Returns:
            img: float32 tensor of shape (3, img_size, img_size), RGB scaled
                to [0, 1].
            mask: float32 tensor of shape (1, img_size, img_size), 1.0 where
                the mask file is non-zero and 0.0 elsewhere.

        Raises:
            OSError: if OpenCV cannot read/decode the image or mask file.
        """
        img_path, mask_path = self.fns[index]
        size = (self.cfg.img_size, self.cfg.img_size)

        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, size)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask: {mask_path}")
        # Resize the mask to the same size as the image; nearest-neighbour
        # avoids blending values across the mask boundary.
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        mask = (mask > 0).astype(np.float32)[np.newaxis]

        img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
        mask = torch.from_numpy(mask).contiguous()
        return img, mask

    def collate_fn(self, batch):
        """Stack a list of ``(img, mask)`` pairs into a dict of batched tensors."""
        imgs = torch.stack([sample[0] for sample in batch])
        masks = torch.stack([sample[1] for sample in batch])
        return {
            'imgs': imgs,
            'masks': masks,
        }

    def __len__(self):
        """Return the number of image/mask pairs found by ``load_images``."""
        return len(self.fns)