# get_threshold.py (4.4 KB)
import os
import time
import sys
import torch.nn.functional as F

import numpy as np
import PIL
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import yaml
import cv2
from get_mean_std import get_params
# Make sibling modules (e.g. `model`) importable regardless of the working
# directory the script is launched from: `__file__` is this script's path.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from model import mobilenetv3

# Top-2 probability-gap thresholds. save_data writes an image into
# threshold/<t>/ for every t that is strictly greater than the gap between
# its two highest (rounded) softmax probabilities.
thresholds = [.05, .1, .15, .2, .25, .3, .35, .4, .45, .5]

# Create the output root and one sub-directory per threshold. exist_ok makes
# reruns idempotent and avoids the check-then-create race of
# os.path.exists + os.mkdir.
os.makedirs("threshold", exist_ok=True)
for threshold in thresholds:
    os.makedirs(f"threshold/{threshold}", exist_ok=True)


def get_args_from_yaml(file='trainer/configs/Error_config.yml'):
    """Load a YAML config file and return its parsed contents.

    Args:
        file: Path to the YAML config file.

    Returns:
        The parsed YAML document (used by this script as a nested dict).
    """
    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
    # and raises TypeError on PyYAML 6. safe_load also refuses to build
    # arbitrary Python objects, which a plain config file never needs.
    with open(file, encoding='utf-8') as f:
        return yaml.safe_load(f)

class MyImageFolder(datasets.ImageFolder):
    """ImageFolder whose items also carry the matching ``self.imgs`` record.

    Each item is ``((image, target), record)``, where ``record`` is
    ``self.imgs[index]``. torchvision stores ``(path, class_index)`` there,
    which lets callers find the source file of every prediction.
    """

    def __getitem__(self, index):
        sample = super().__getitem__(index)
        return sample, self.imgs[index]

def main(args):
    """Run the threshold extraction described by ``args`` and report completion."""
    run_model(args)
    run_id = args['id']
    print(f"[{run_id}] done")

def run_model(args):
    """Build the validation loader and model, load the checkpoint, then dump images.

    Args:
        args: Config dict as returned by ``get_args_from_yaml``. Keys read here:
            ``train.size``, ``gpu`` (list of device ids; the first one is the
            primary device), ``data.train``, ``data.val``, ``predict.worker``,
            ``predict.batch-size``, ``model.class``, ``model.blocks`` and
            ``checkpoint`` (path to a file holding ``state_dict`` and ``epoch``).
    """
    resize_size = args['train']['size']

    gpus = args['gpu']

    # Normalisation statistics come from get_params on the training data. Only
    # the first entry is used because the pipeline below converts every image
    # to a single grayscale channel.
    mean, std = get_params(args['data']['train'], resize_size)

    normalize = transforms.Normalize(mean=[mean[0].item()],
                                     std=[std[0].item()])

    normalize_factor = [mean, std]

    # data loader
    transform_test = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        normalize
    ])
    kwargs = {'num_workers': args['predict']['worker'], 'pin_memory': True}
    test_data = MyImageFolder(args['data']['val'], transform_test)
    val_loader = torch.utils.data.DataLoader(
        test_data, batch_size=args['predict']['batch-size'], shuffle=False,
        **kwargs)

    # load model
    model = mobilenetv3(n_class=args['model']['class'], blocknum=args['model']['blocks'])

    torch.cuda.set_device(gpus[0])
    with torch.cuda.device(gpus[0]):
        model = model.cuda()

    # NOTE(review): the model is wrapped before load_state_dict, so the
    # checkpoint presumably was saved from a DataParallel model as well.
    model = torch.nn.DataParallel(model, device_ids=gpus, output_device=gpus[0])

    print("=> loading checkpoint '{}'".format(args['checkpoint']))
    # Without map_location, torch.load restores each tensor onto the device it
    # was saved from. That breaks when that GPU id does not exist here, so map
    # everything onto the primary GPU selected above.
    checkpoint = torch.load(args['checkpoint'], map_location=f"cuda:{gpus[0]}")
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args['checkpoint'], checkpoint['epoch']))
    cudnn.benchmark = True

    extract_data(val_loader, model, normalize_factor, args)


def extract_data(val_loader, model, normalize_factor, args):
    """Run ``model`` over ``val_loader`` and hand each batch's outputs to ``save_data``.

    Each loader item is ``((images, labels), (paths, _))``, matching what
    ``MyImageFolder`` yields after batching. ``normalize_factor`` and ``args``
    are accepted for interface compatibility but are not used here.
    """
    with torch.no_grad():
        model.eval()  # switch to evaluate mode
        for (images, labels), (paths, _) in val_loader:
            labels = labels.cuda()
            images = images.cuda()

            logits = model(images)

            print("save data!")
            save_data(logits, labels, paths)

class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # sum / count
        self.sum = 0    # weighted sum of every value seen
        self.count = 0  # total weight (number of samples) seen

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average.

        If the total count is still zero (e.g. ``n == 0`` on a fresh meter),
        ``avg`` is left unchanged instead of raising ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: Score tensor of shape (batch, num_classes); classes are ranked
            along dim 1.
        target: Tensor with one ground-truth class index per sample (batch,).
        topk: Values of k to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage (0-100) of samples
        whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row i holds every sample's i-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and can be non-contiguous,
        # in which case .view(-1) raises; reshape() copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def save_data(output, target, path):
    """Copy images whose top-2 class probabilities are close into threshold folders.

    Softmax probabilities are rounded to 3 decimals and the gap between each
    sample's two highest is compared with every value in the module-level
    ``thresholds``. For each threshold strictly greater than the gap, the image
    is written to ``threshold/<t>/pred_<top1>_<top2>_<filename>``.

    Args:
        output: Logits, one row per sample (softmax is taken over dim 1).
        target: Unused; kept for interface compatibility.
        path: Sequence of image file paths aligned with the rows of ``output``.
    """
    n_digits = 3
    scale = 10 ** n_digits
    prob = F.softmax(output, dim=1)
    prob = torch.round(prob * scale) / scale
    # One batched top-2 instead of two per-sample topk calls; .tolist() moves
    # the results to plain Python lists in a single transfer.
    top_values, top_indices = torch.topk(prob, 2, dim=1)
    for idx, (value, indice) in enumerate(zip(top_values.tolist(), top_indices.tolist())):
        gap = abs(value[0] - value[1])
        matching = [threshold for threshold in thresholds if gap < threshold]
        if not matching:
            continue
        # Read each image at most once, however many thresholds it falls under.
        img = cv2.imread(path[idx])
        if img is None:
            # cv2.imread signals failure with None; cv2.imwrite would then raise.
            print(f"could not read '{path[idx]}', skipping")
            continue
        # basename honours the platform separator (e.g. '\\' on Windows),
        # unlike split('/').
        filename = os.path.basename(path[idx])
        for threshold in matching:
            cv2.imwrite(f'threshold/{threshold}/pred_{indice[0]}_{indice[1]}_{filename}', img)

if __name__ == '__main__':
    # Script entry point: load the combined config, tag this run, execute it.
    config = get_args_from_yaml('configs/All_config.yml')
    config.update(config='All', id='threshold')
    main(config)