align.py
5.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from functools import partial
from multiprocessing import Pool
import os
import re
import cropper
import numpy as np
import tqdm
# ==============================================================================
# =                                    param                                   =
# ==============================================================================

def _build_parser():
    """Create the command-line parser for this alignment script.

    The first group of options selects input/output locations, crop geometry,
    output format and worker count; the second group tunes the alignment itself
    (face scale, transform type, interpolation order and border mode).
    """
    p = argparse.ArgumentParser()

    # main
    p.add_argument('--img_dir', dest='img_dir', default='./data/img_celeba')
    p.add_argument('--save_dir', dest='save_dir', default='./data/aligned')
    p.add_argument('--landmark_file', dest='landmark_file', default='./data/landmark.txt')
    p.add_argument('--standard_landmark_file', dest='standard_landmark_file',
                   default='./data/standard_landmark_68pts.txt')
    p.add_argument('--crop_size_h', dest='crop_size_h', type=int, default=572)
    p.add_argument('--crop_size_w', dest='crop_size_w', type=int, default=572)
    p.add_argument('--move_h', dest='move_h', type=float, default=0.25)
    p.add_argument('--move_w', dest='move_w', type=float, default=0.)
    p.add_argument('--save_format', dest='save_format', choices=['jpg', 'png'], default='jpg')
    p.add_argument('--n_worker', dest='n_worker', type=int, default=8)

    # others
    p.add_argument('--face_factor', dest='face_factor', type=float,
                   help='The factor of face area relative to the output image.', default=0.45)
    p.add_argument('--align_type', dest='align_type', choices=['affine', 'similarity'],
                   default='similarity')
    p.add_argument('--order', dest='order', type=int, choices=[0, 1, 2, 3, 4, 5],
                   help='The order of interpolation.', default=3)
    p.add_argument('--mode', dest='mode',
                   choices=['constant', 'edge', 'symmetric', 'reflect', 'wrap'], default='edge')
    return p


parser = _build_parser()
args = parser.parse_args()
# ==============================================================================
# =                                opencv first                                =
# ==============================================================================

_DEAFAULT_JPG_QUALITY = 95  # JPEG quality handed to whichever image writer is selected below

# Prefer OpenCV and fall back to scikit-image only when cv2 cannot be imported.
# The try body holds just the import: a bare `except:` around the whole branch
# would also hide unrelated errors (e.g. a missing `cropper.align_crop_opencv`)
# and silently switch backends.
try:
    import cv2
except ImportError:
    cv2 = None

if cv2 is not None:
    imread = cv2.imread
    imwrite = partial(cv2.imwrite, params=[int(cv2.IMWRITE_JPEG_QUALITY), _DEAFAULT_JPG_QUALITY])
    align_crop = cropper.align_crop_opencv
    print('Use OpenCV')
else:
    import skimage.io as io
    imread = io.imread
    imwrite = partial(io.imsave, quality=_DEAFAULT_JPG_QUALITY)
    align_crop = cropper.align_crop_skimage
    print('Importing OpenCv fails. Use scikit-image')
# ==============================================================================
# =                                     run                                    =
# ==============================================================================

# count landmarks: a landmark-file row is "<name> x1 y1 x2 y2 ...", so the number
# of landmarks is half the number of fields after the name.
with open(args.landmark_file) as f:
    line = f.readline()
# str.split() splits on any whitespace run and drops the trailing newline, which
# matches how np.genfromtxt tokenizes the same file in the __main__ block.
n_landmark = len(line.split()[1:]) // 2

# load standard landmark (the template the faces are aligned to), then shift it:
# column 0 by move_w and column 1 by move_h.
# NOTE: builtin `float` instead of `np.float`, which was removed in NumPy 1.24.
standard_landmark = np.genfromtxt(args.standard_landmark_file, dtype=float).reshape(n_landmark, 2)
standard_landmark[:, 0] += args.move_w
standard_landmark[:, 1] += args.move_h

# data dir: the output folder name encodes the settings that affect the crops.
save_dir = os.path.join(args.save_dir, 'align_size(%d,%d)_move(%.3f,%.3f)_face_factor(%.3f)_%s' % (args.crop_size_h, args.crop_size_w, args.move_h, args.move_w, args.face_factor, args.save_format))
data_dir = os.path.join(save_dir, 'data')
# exist_ok avoids the check-then-create race of isdir() followed by makedirs().
os.makedirs(data_dir, exist_ok=True)
def work(name, landmark) -> "str | None":  # a single work
    """Align and crop one image, save it, and return its output landmark line.

    Runs in a worker process. Uses module-level settings (`args`, `imread`,
    `align_crop`, `imwrite`, `standard_landmark`, `data_dir`, `n_landmark`).

    Args:
        name: image path relative to ``args.img_dir``.
        landmark: this image's source landmarks; the __main__ block passes one
            ``(n_landmark, 2)`` slice of the landmark table.

    Returns:
        ``"<saved name> x1 y1 x2 y2 ..."`` built from the landmarks returned by
        `align_crop`, or ``None`` if all three attempts failed.
    """
    # Derive the output name once into separate variables. Rebinding `name`
    # inside the loop made a retry re-read the image under the OUTPUT extension.
    save_name = os.path.splitext(name)[0] + '.' + args.save_format
    save_path = os.path.join(data_dir, save_name)

    error = None
    for _ in range(3):  # try three times
        try:
            img = imread(os.path.join(args.img_dir, name))
            if img is None:  # cv2.imread reports an unreadable file by returning None
                raise OSError('cannot read image %s' % name)
            img_crop, tformed_landmarks = align_crop(img,
                                                     landmark,
                                                     standard_landmark,
                                                     crop_size=(args.crop_size_h, args.crop_size_w),
                                                     face_factor=args.face_factor,
                                                     align_type=args.align_type,
                                                     order=args.order,
                                                     mode=args.mode)
            # Several workers may create the same sub-directory concurrently.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            imwrite(save_path, img_crop)

            coords = np.asarray(tformed_landmarks).ravel()
            return ('%s' + ' %.1f' * n_landmark * 2) % ((save_name, ) + tuple(coords))
        except Exception as e:  # not bare: let KeyboardInterrupt/SystemExit propagate
            error = e
    print('%s fails! (%s)' % (name, error))
    return None
if __name__ == "__main__":
    # Each row of the landmark file is "<image name> x1 y1 x2 y2 ...".
    # Builtin str/float replace np.str/np.float (removed in NumPy 1.24);
    # atleast_1d keeps a one-row file from collapsing to a 0-d array.
    img_names = np.atleast_1d(np.genfromtxt(args.landmark_file, dtype=str, usecols=0))
    landmarks = np.genfromtxt(args.landmark_file, dtype=float,
                              usecols=range(1, n_landmark * 2 + 1)).reshape(-1, n_landmark, 2)
    n_pics = len(img_names)

    landmarks_path = os.path.join(save_dir, 'landmark.txt')
    # Pool.__exit__ terminates the workers, so close()/join() below give the
    # normal shutdown; if anything raises, the pool is still torn down.
    with Pool(args.n_worker) as pool, \
            open(landmarks_path, 'w') as f, \
            tqdm.tqdm(total=n_pics) as bar:
        # Submit every image before waiting on any result. Calling .get() right
        # after each submission kept one task in flight and serialized the pool.
        tasks = [pool.apply_async(work, (img_name, landmark), callback=lambda _: bar.update())
                 for img_name, landmark in zip(img_names, landmarks)]
        pool.close()

        # Collect in submission order so landmark.txt follows the input file's order.
        for img_name, task in zip(img_names, tasks):
            try:
                result = task.get()
            except Exception as e:  # work() handles its own errors; report anything else
                print('%s fails! (%s)' % (img_name, e))
                continue
            if result:  # work() returns None when every attempt failed
                f.write(result + '\n')
        pool.join()