# main.py (6.19 KB)
import argparse
import random
import os
import cv2
import logging
import datetime

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

from model import mobilenetv3
from utils import get_args_from_yaml, MyImageFolder
from get_mean_std import get_params

## This script contains the complete end-to-end inference pipeline.

# make Logger
# NOTE(review): os.path.dirname(__name__) is "" for any top-level module name
# (e.g. "__main__" or "main"), so this is effectively the root logger. Kept
# as-is so the same records keep reaching these handlers.
logger = logging.getLogger(os.path.dirname(__name__))
logger.setLevel(logging.INFO)

# make Logger stream: echo log records to the console (stderr by default).
streamHandler = logging.StreamHandler()
logger.addHandler(streamHandler)

# One output folder per predicted class; images are copied into them by run().
# os.makedirs(..., exist_ok=True) also creates the missing "eval_results"
# parent (plain os.mkdir raises FileNotFoundError there) and is a no-op for
# folders that already exist.
_RESULT_ROOT = os.path.join('eval_results', 'main')
_RESULT_CLASSES = ('Normal', 'Crack', 'Empty', 'Flip', 'Pollute', 'Double', 'Leave', 'Scratch')
for _class_name in _RESULT_CLASSES:
    os.makedirs(os.path.join(_RESULT_ROOT, _class_name), exist_ok=True)
del _class_name


def main(Error_args, Error_Type_args):
    """Attach a timestamped log file to ``logger`` and run the inference pipeline.

    Args:
        Error_args: config dict for the stage-1 (error vs. normal) model,
            forwarded unchanged to ``run``.
        Error_Type_args: config dict for the stage-2 (error type) model,
            forwarded unchanged to ``run``.
    """
    logdir = os.path.join("logs", "main")
    # makedirs also creates a missing "logs" parent; os.mkdir would raise.
    os.makedirs(logdir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H:%M:%S')
    file_handler = logging.FileHandler(os.path.join(logdir, f"{timestamp}_log.log"))
    logger.addHandler(file_handler)
    try:
        run(Error_args, Error_Type_args)
    finally:
        # Detach and close the handler so the log file is flushed and its
        # descriptor released even if run() raises, and so repeated calls
        # don't keep stacking file handlers on the logger.
        logger.removeHandler(file_handler)
        file_handler.close()

# Output sub-folder (under eval_results/main) for each class index of the
# stage-2 error-type model: 0=Crack, 1=Double, 2=Empty, 3=Flip, 4=Leave,
# 5=Pollute, 6=Scratch. Indices outside this range are not saved.
_ERROR_TYPE_DIRS = ('Crack', 'Double', 'Empty', 'Flip', 'Leave', 'Pollute', 'Scratch')


def _load_model(args, gpus):
    """Build a MobileNetV3 from ``args``, load ``args['checkpoint']`` and return it in eval mode."""
    model = mobilenetv3(n_class=args['model']['class'], blocknum=args['model']['blocks'])
    model = model.cuda(gpus[0])
    # Wrapped before loading: the checkpoint's state_dict is loaded into the
    # DataParallel wrapper, exactly as the original code did.
    model = torch.nn.DataParallel(model, device_ids=gpus, output_device=gpus[0])
    checkpoint = torch.load(args['checkpoint'])
    model.load_state_dict(checkpoint['state_dict'])
    # Inference only: use BatchNorm running statistics (and don't update
    # them) and disable dropout.
    return model.eval()


def _copy_to_class_dir(src_path, class_dir):
    """Re-save the image at ``src_path`` as eval_results/main/<class_dir>/<file name>.

    The image is re-encoded through OpenCV (imread -> imwrite), not byte-copied.
    Images OpenCV cannot read are skipped with a warning instead of crashing in
    imwrite.
    """
    img = cv2.imread(src_path)
    if img is None:
        logger.warning("cv2 could not read %s; not saved to %s", src_path, class_dir)
        return
    cv2.imwrite(os.path.join('eval_results', 'main', class_dir, os.path.basename(src_path)), img)


def run(Error_args, Error_Type_args):
    """Two-stage inference over the images in ``Error_args['data']['test']``.

    Stage 1 (Error model): predicted class index 0 means "error"; any other
    index means "normal", and those images are copied to eval_results/main/Normal.
    Stage 2 (Error_Type model): the stage-1 error images are classified again
    and copied to the folder given by ``_ERROR_TYPE_DIRS[class index]``.

    Args:
        Error_args: config dict for the stage-1 model. It also supplies the
            gpus, input size, test data path, batch size and worker count.
        Error_Type_args: config dict for the stage-2 model (only 'model' is read).

    Both dicts get their 'checkpoint' entry overwritten with the hard-coded
    paths below.
    """
    Error_args['checkpoint'] = "output/Error/25678_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar"
    Error_Type_args['checkpoint'] = "output/ErrorType/2798_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar"

    gpus = Error_args['gpu']
    resize_size = Error_args['train']['size']

    torch.cuda.set_device(gpus[0])
    Error_model = _load_model(Error_args, gpus)
    Error_Type_model = _load_model(Error_Type_args, gpus)

    # Single-channel normalisation statistics computed on the test set.
    mean, std = get_params(Error_args['data']['test'], resize_size)
    normalize = transforms.Normalize(mean=[mean[0].item()], std=[std[0].item()])
    transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = MyImageFolder(Error_args['data']['test'], transform)
    print(len(dataset))

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=Error_args['predict']['batch-size'], shuffle=False,
        num_workers=Error_args['predict']['worker'], pin_memory=True
    )

    # Inference only: don't build autograd graphs.
    with torch.no_grad():
        for (images, _), (paths, _) in loader:
            images = images.cuda()

            # Stage 1: class 0 = error, anything else = normal.
            is_error = Error_model(images).argmax(dim=1) == 0
            error_flags = is_error.tolist()

            error_paths = []
            for src_path, flagged in zip(paths, error_flags):
                if flagged:
                    error_paths.append(src_path)
                else:
                    _copy_to_class_dir(src_path, 'Normal')

            n_error = len(error_paths)
            print(f"error path : {n_error}")
            print(f"error : {n_error}")
            print(f"normal : {images.shape[0] - n_error}")

            # Boolean-mask selection keeps batch order (so it lines up with
            # error_paths) and works for any input size, unlike concatenating
            # onto a fixed 64x64 placeholder.
            error_images = images[is_error]
            print(error_images.shape[0])
            if error_images.shape[0] == 0:
                # Nothing to classify in this batch; don't run the model on
                # an empty batch.
                continue

            # Stage 2: route each error image to its error-type folder.
            type_preds = Error_Type_model(error_images).argmax(dim=1).tolist()
            for src_path, type_idx in zip(error_paths, type_preds):
                if type_idx < len(_ERROR_TYPE_DIRS):
                    _copy_to_class_dir(src_path, _ERROR_TYPE_DIRS[type_idx])
            

if __name__ == '__main__':
    # Script entry point: parse the YAML configs for the stage-1 error detector
    # and the stage-2 error-type classifier (left to right), then run inference.
    Error_args, Error_Type_args = (
        get_args_from_yaml("configs/Error_config.yml"),
        get_args_from_yaml("configs/ErrorType_config.yml"),
    )
    main(Error_args, Error_Type_args)