# finetune.py
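"""Fine-tune a MobileNetV3 classifier on the grayscale dataset in ../data/All.

The script reloads a pretrained checkpoint (minus its classifier head),
freezes the early modules, retrains the rest, and checkpoints the best
validation Prec@1.
"""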
import torch
import torch.nn as nn
import os
import shutil
import logging
from model import mobilenetv3
from utils import get_args_from_yaml
import torchvision.datasets as datasets
from utils import AverageMeter, accuracy, printlog, precision, recall
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import time
from get_mean_std import get_params

model = mobilenetv3(n_class=7, blocknum=6, dropout=0.5)
model = model.train()
data_path = "../data/All"
check_path = "output/All/30114_model=MobilenetV3-ep=3000-block=6-class=8/model_best.pth.tar"  # checkpoint from a previous run; its classifier head is replaced below
validation_ratio = 0.1     # fraction of the dataset held out for validation
random_seed = 10           # fixed seed so the train/val split is reproducible
gpus = [0]
epochs = 3000
resize_size = 128          # images are resized to resize_size x resize_size

logger = logging.getLogger()
logger.setLevel(logging.INFO)
streamHandler = logging.StreamHandler()
logger.addHandler(streamHandler)

os.makedirs("logs", exist_ok=True)  # FileHandler raises if the directory is missing
fileHandler = logging.FileHandler("logs/finetune.log")
logger.addHandler(fileHandler)


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save a checkpoint to disk and copy it to model_best when it is the best so far."""
    directory = "output/All/"
    os.makedirs(directory, exist_ok=True)
    filename = directory + filename
    torch.save(state, filename)
    logger.info(f"Checkpoint Saved: {filename}")
    best_filename = directory + "model_best.pth.tar"
    if is_best:
        shutil.copyfile(filename, best_filename)
        logger.info(f"New Best Checkpoint saved: {best_filename}")

    return best_filename

def validate(val_loader, model, criterion, epoch, q=None):
    """Perform validaadd_model_to_queuetion on the validation set"""
    with torch.no_grad():
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        prec = []
        rec = []

        for i in range(7):
            prec.append(AverageMeter())
            rec.append(AverageMeter())
        # switch to evaluate mode
        model.eval()
        end = time.time()

        for i, (input, target) in enumerate(val_loader):
            if torch.cuda.is_available():
                target = target.cuda()
                input = input.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target, topk=(1,))[0]

            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            for k in range(7):
                prec[k].update(precision(output.data, target, target_class=k), input.size(0))
                rec[k].update(recall(output.data, target, target_class=k), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                logger.info('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'
                        .format(
                            i, len(val_loader), batch_time=batch_time, loss=losses,
                            top1=top1))

        printlog(' * epoch: {epoch} Prec@1 {top1.avg:.3f}'.format(epoch=epoch, top1=top1), logger, q)

    return top1.avg, prec, rec


def train(model, train_loader, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    prec = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    for i, (input, target) in enumerate(train_loader):
        if torch.cuda.is_available():
            target = target.cuda()
            input = input.cuda()
        # compute output
        output = model(input)
        loss = criterion(output, target)
        # measure accuracy and record loss
        prec1 = accuracy(output, target, topk=(1,))[0]

        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        .format(
                            epoch, i, len(train_loader), batch_time=batch_time,
                            loss=losses, top1=top1))

# Freeze everything up to module index 62 (the early feature extractor);
# leave the later blocks and the classifier trainable.
for idx, (name, module) in enumerate(model.named_modules()):
    requires_grad = idx >= 62
    for param in module.parameters():
        param.requires_grad = requires_grad
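
# Optional sanity check: log how much of the network remains trainable
# after freezing.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
logger.info(f"Trainable parameters: {n_trainable}/{n_total}")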

# Dataset mean/std; only channel 0 is used because images are converted to grayscale.
mean, std = get_params(data_path, resize_size)
normalize = transforms.Normalize(mean=[mean[0].item()],
                                 std=[std[0].item()])

transform_train = transforms.Compose([
    transforms.Resize((resize_size, resize_size)),   # resize to a fixed square
    transforms.ColorJitter(0.2, 0.2, 0.2),           # jitter brightness, contrast, saturation
    transforms.RandomRotation(2),                    # rotate within -2 to +2 degrees
    transforms.RandomAffine(5),                      # random affine (rotation up to 5 degrees)
    transforms.RandomCrop(resize_size, padding=2),   # pad each side by 2, then crop back to resize_size
    transforms.RandomHorizontalFlip(),               # random left-right flip
    transforms.Grayscale(),
    transforms.ToTensor(),
    normalize
])

transform_test = transforms.Compose([
    transforms.Resize((resize_size, resize_size)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    normalize
])

kwargs = {'num_workers': 16, 'pin_memory': True}

train_data = datasets.ImageFolder(data_path, transform_train)
val_data = datasets.ImageFolder(data_path,transform_test)
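
# Both ImageFolder instances point at the same directory; the disjoint
# samplers built below determine which images each loader actually sees.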


num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(validation_ratio * num_train))

# Fix the random seed (10 for both train and test runs) so the same
# train/val split is reproduced every time.
np.random.seed(random_seed)
np.random.shuffle(indices)

# Split the shuffled indices into train and validation sets.
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
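
# Optional sanity check: the two index sets must not overlap.
assert not set(train_idx) & set(valid_idx), "train/val indices overlap"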

# A DataLoader's sampler option is mutually exclusive with shuffle, so
# shuffle is left unset here.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=256, sampler=train_sampler, **kwargs)
val_loader = torch.utils.data.DataLoader(
    val_data, batch_size=256, sampler=valid_sampler, **kwargs)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
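# Frozen parameters never receive gradients, so Adam effectively skips them;
# an equivalent alternative is to hand the optimizer only the trainable subset:
#   torch.optim.Adam((p for p in model.parameters() if p.requires_grad),
#                    lr=1e-4, weight_decay=1e-4)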

if torch.cuda.is_available():
    torch.cuda.set_device(gpus[0])
    with torch.cuda.device(gpus[0]):
        model = model.cuda()
        criterion = criterion.cuda()
    model = torch.nn.DataParallel(model, device_ids=gpus, output_device=gpus[0])

# map_location='cpu' keeps the load working on CPU-only machines; tensors are
# moved to the model's device when the state dict is copied in.
checkpoint = torch.load(check_path, map_location='cpu')

# Partial weight transfer: copy every pretrained tensor except the classifier
# head, which keeps its fresh initialization for the new 7-class output.
pretrained_dict = checkpoint['state_dict']
new_model_dict = model.state_dict()
for k, v in pretrained_dict.items():
    if 'classifier' in k:
        continue
    new_model_dict[k] = v
model.load_state_dict(new_model_dict)
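
# Optional: make the partial load visible in the log.
skipped = [k for k in pretrained_dict if 'classifier' in k]
logger.info(f"Skipped pretrained keys: {skipped}")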

#model.load_state_dict(checkpoint['state_dict'], strict=False)
best_prec1 = checkpoint['best_prec1']

for epoch in range(epochs):
    train(model, train_loader, criterion, optimizer, epoch)

    prec1, prec, rec = validate(val_loader, model, criterion, epoch)

    is_best = prec1 >= best_prec1
    best_prec1 = max(prec1, best_prec1)

    best_path = save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
    }, is_best)


# Per-class precision/recall from the final validation pass.
for i in range(len(prec)):
    logger.info(' * class {0} Precision {prec.avg:.3f}'.format(i, prec=prec[i]))
    logger.info(' * class {0} Recall {rec.avg:.3f}'.format(i, rec=rec[i]))