train.py
from statistics import mean
from math import isfinite
import torch
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import LambdaLR, SAVE_STATE_WARNING
from apex import amp, optimizers
from apex.parallel import DistributedDataParallel as ADDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler, autocast
from .backbones.layers import convert_fixedbn_model
from .data import DataIterator, RotatedDataIterator
from .dali import DaliDataIterator
from .utils import ignore_sigint, post_metrics, Profiler
from .infer import infer
import warnings
warnings.filterwarnings('ignore', message=SAVE_STATE_WARNING, category=UserWarning)


def train(model, state, path, annotations, val_path, val_annotations, resize, max_size, jitter, batch_size, iterations,
          val_iterations, mixed_precision, lr, warmup, milestones, gamma, rank=0, world=1, no_apex=False, use_dali=True,
          verbose=True, metrics_url=None, logdir=None, rotate_augment=False, augment_brightness=0.0,
          augment_contrast=0.0, augment_hue=0.0, augment_saturation=0.0, regularization_l2=0.0001, rotated_bbox=False,
          absolute_angle=False):
    'Train the model on the given dataset'

    # Prepare model
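    # Keep a handle on the unwrapped model: `nn_model.save()` is used for checkpointing after AMP/DDP wrapping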
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.to(memory_format=torch.channels_last).cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=regularization_l2, momentum=0.9)

    is_master = rank == 0
    if not no_apex:
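        # Apex AMP: dynamic loss scaling with the DALI pipeline, a fixed scale of 128 with the PyTorch loader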
        loss_scale = "dynamic" if use_dali else "128.0"
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level='O2' if mixed_precision else 'O0',
                                          keep_batchnorm_fp32=True,
                                          loss_scale=loss_scale,
                                          verbosity=is_master)

    if world > 1:
        model = DDP(model, device_ids=[rank]) if no_apex else ADDP(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
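
    # LR schedule: linear warmup from 10% to 100% of the base LR, then step decay by `gamma` at each milestone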
    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    if rotated_bbox:
        if use_dali: raise NotImplementedError("This repo does not currently support DALI for rotated bbox detections.")
        data_iterator = RotatedDataIterator(path, jitter, max_size, batch_size, stride,
                                            world, annotations, training=True, rotate_augment=rotate_augment,
                                            augment_brightness=augment_brightness,
                                            augment_contrast=augment_contrast, augment_hue=augment_hue,
                                            augment_saturation=augment_saturation, absolute_angle=absolute_angle)
    else:
        data_iterator = (DaliDataIterator if use_dali else DataIterator)(
            path, jitter, max_size, batch_size, stride,
            world, annotations, training=True, rotate_augment=rotate_augment, augment_brightness=augment_brightness,
            augment_contrast=augment_contrast, augment_hue=augment_hue, augment_saturation=augment_saturation)
    if verbose: print(data_iterator)

    if verbose:
        print(' device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else 'GPU' if world == 1 else 'GPUs'))
        print(' batch: {}, precision: {}'.format(batch_size, 'mixed' if mixed_precision else 'full'))
        print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned')
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if is_master and logdir is not None:
        from torch.utils.tensorboard import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)
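
    # GradScaler and autocast implement native torch.cuda.amp mixed precision on the no_apex path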
    scaler = GradScaler()
    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
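
    # Main training loop; `iteration` resumes from the value stored in the checkpoint state, if any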
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            if iteration >= iterations:
                break

            # Forward pass
            profiler.start('fw')
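            # In training mode the model takes [images, targets] and returns classification (focal) and box regression losses;
            # inputs are made channels_last to match the model's memory format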
            optimizer.zero_grad()
            if not no_apex:
                cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target])
            else:
                with autocast():
                    cls_loss, box_loss = model([data.contiguous(memory_format=torch.channels_last), target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
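            # Apex AMP scales the summed loss explicitly; the native path lets GradScaler handle scaling and the optimizer step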
            if not no_apex:
                with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                scaler.scale(cls_loss + box_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
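            # With multiple GPUs, all-reduce and divide by world size so the logged losses are averaged across ranks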
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
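            # On the master rank, log, post metrics and checkpoint roughly once per minute of training time, and always at the final iteration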
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations, len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if is_master and logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate
                    })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]
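
            # Run validation every `val_iterations` iterations and at the end of training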
            if val_annotations and (iteration == iterations or iteration % val_iterations == 0):
                stats = infer(model, val_path, None, resize, max_size, batch_size, annotations=val_annotations,
                              mixed_precision=mixed_precision, is_master=is_master, world=world, use_dali=use_dali,
                              no_apex=no_apex, is_validation=True, verbose=False, rotated_bbox=rotated_bbox)
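                # Restore training mode after running validation inference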
                model.train()
                if is_master and logdir is not None and stats is not None:
                    writer.add_scalar(
                        'Validation_Precision/mAP', stats[0], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP@0.50IoU', stats[1], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP@0.75IoU', stats[2], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (small)', stats[3], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (medium)', stats[4], iteration)
                    writer.add_scalar(
                        'Validation_Precision/mAP (large)', stats[5], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 1 Dets)', stats[6], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 10 Dets)', stats[7], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (max 100 Dets)', stats[8], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (small)', stats[9], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (medium)', stats[10], iteration)
                    writer.add_scalar(
                        'Validation_Recall/mAR (large)', stats[11], iteration)

            if (iteration == iterations and not rotated_bbox) or (iteration > iterations and rotated_bbox):
                break

    if is_master and logdir is not None:
        writer.close()