# roitrain.py: trains the ROIPool distance head on features from a frozen YOLOv3-tiny (Darknet) detector.
from __future__ import division
from roipool2 import *
from models import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
# from test import evaluate
from terminaltables import AsciiTable
import os
import sys
import time
import datetime
import argparse
import warnings
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
import warnings
# Silence UserWarnings for the whole run. This statement executes at import
# time, outside the ``__main__`` guard below.
warnings.filterwarnings("ignore", category=UserWarning)


def main():
    """Train the ``ROIPool`` distance head on top of a frozen YOLOv3-tiny detector.

    A ``Darknet`` detector is restored from a checkpoint and kept in eval mode;
    only the ``ROIPool`` parameters are optimised (Adam, default settings).
    For each batch, the detector's feature map and the NMS-filtered detections
    of the first image are passed to the distance head, which returns
    ``(loss, outputs)``; batches whose image has no detections are skipped.
    After every 10th epoch (epoch 0 included) the learning rate is halved and
    a checkpoint ``tiny1_<epoch>.pth`` is written to ``--checkpoint_dir``.

    Every command-line option defaults to the value this script previously
    hard-coded, so running it without arguments behaves as before.
    """
    parser = argparse.ArgumentParser(
        description="Train the ROIPool distance head on a frozen YOLOv3-tiny detector.")
    parser.add_argument("--data_config", default="config/cafe_distance.data",
                        help="data config file; its 'train' entry is used")
    parser.add_argument("--model_def", default="config/yolov3-tiny.cfg",
                        help="Darknet model definition file")
    parser.add_argument("--img_size", type=int, default=416,
                        help="second argument passed to Darknet")
    parser.add_argument("--detector_weights",
                        default="checkpoints_cafe_distance/tiny1_2500.pth",
                        help="state_dict for the frozen detector")
    parser.add_argument("--checkpoint_dir", default="checkpoints_distance11",
                        help="where distance-head checkpoints are written")
    parser.add_argument("--epochs", type=int, default=2000,
                        help="number of training epochs")
    parser.add_argument("--conf_thres", type=float, default=0.8,
                        help="first threshold passed to non_max_suppression")
    parser.add_argument("--nms_thres", type=float, default=0.4,
                        help="second threshold passed to non_max_suppression")
    opt = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device: ', device)

    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]

    # Frozen detector: weights are loaded once and never handed to the optimizer.
    model = Darknet(opt.model_def, opt.img_size).to(device)
    model.load_state_dict(torch.load(opt.detector_weights, map_location=device))
    model.eval()

    dataset = ListDataset(train_path, augment=True, multiscale=True)
    # batch_size must stay 1: only index 0 of detections/targets is used below.
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    model_distance = ROIPool((3, 3)).to(device)
    n_trainable = sum(p.numel() for p in model_distance.parameters() if p.requires_grad)
    print('Params: ', n_trainable)
    optimizer = torch.optim.Adam(model_distance.parameters())

    # Create the output directory up front so checkpoint saves cannot fail on it.
    os.makedirs(opt.checkpoint_dir, exist_ok=True)

    interval = 10  # epochs between LR halvings and checkpoints (epoch 0 included)
    for epoch in range(opt.epochs):
        for _img_paths, imgs, _targets, targets_distance in dataloader:
            imgs = imgs.to(device)
            # The detector is frozen, so no autograd graph is built for it.
            with torch.no_grad():
                featuremap, detections = model(imgs)
            featuremap = featuremap.to(device)
            detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)

            # NOTE(review): assumes non_max_suppression returns one entry per
            # image, None when nothing survives the thresholds (as in
            # PyTorch-YOLOv3's utils) -- TODO confirm. With no detections there
            # is nothing to regress a distance for, so skip the batch rather
            # than passing None on to the distance head.
            if detections is None or detections[0] is None:
                continue

            # Same tensors the original built with Variable(..., requires_grad=True):
            # detached leaves that require grad.
            # NOTE(review): targets normally need no gradient -- confirm that
            # ROIPool's loss reaches its own parameters before dropping this.
            # targets are not moved to `device` here; presumably ROIPool
            # handles placement -- confirm.
            image_detections = detections[0].detach().requires_grad_(True)
            image_targets = torch.tensor(targets_distance[0]).requires_grad_(True)

            loss, _outputs = model_distance(featuremap, image_detections, targets=image_targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch)
        if epoch % interval == 0:
            for group in optimizer.param_groups:
                group['lr'] /= 2
            torch.save(model_distance.state_dict(),
                       os.path.join(opt.checkpoint_dir, f'tiny1_{epoch}.pth'))


if __name__ == '__main__':
    main()