test_with_lmdb.ipynb 9.18 KB
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m[0603 12:20:31 @format.py:92]\u001b[0m Found 100 entries in ../../../data/shapenet-car/valid.lmdb\n",
      "INFO:tensorflow:Restoring parameters from ./log/pcn_addbeta_lr/model-81500\n",
      "Average Chamfer distance: 0.009029\n",
      "Average Earth mover distance: 0.053403\n"
     ]
    }
   ],
   "source": [
    "# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018\n",
    "\n",
    "import argparse\n",
    "import csv\n",
    "import importlib\n",
    "import models\n",
    "import numpy as np\n",
    "import os\n",
    "import tensorflow as tf\n",
    "import time\n",
    "\n",
    "from tf_util import chamfer, earth_mover\n",
    "from visu_util import plot_pcd_three_views\n",
    "from data_util import lmdb_dataflow, get_queued_data\n",
    "\n",
    "##################\n",
    "#from io_util import read_pcd, save_pcd 이부분 그냥 포함시켜버렸다...ㅎㅎ\n",
    "import numpy as np\n",
    "#from open3d import *\n",
    "import open3d as o3d\n",
    "\n",
    "\n",
    "def read_pcd(filename):\n",
    "    \"\"\"Read a point-cloud file with Open3D and return its points as a NumPy array.\"\"\"\n",
    "    cloud_points = o3d.io.read_point_cloud(filename).points\n",
    "    return np.array(cloud_points)\n",
    "\n",
    "\n",
    "def save_pcd(filename, points):\n",
    "    \"\"\"Write ``points`` to ``filename`` as an Open3D point cloud.\n",
    "\n",
    "    ``points`` is presumably an (N, 3) array of coordinates (the form\n",
    "    ``o3d.utility.Vector3dVector`` accepts); the on-disk format is left to\n",
    "    Open3D (presumably chosen from the file extension).\n",
    "    \"\"\"\n",
    "    cloud = o3d.geometry.PointCloud()\n",
    "    cloud.points = o3d.utility.Vector3dVector(points)\n",
    "    o3d.io.write_point_cloud(filename, cloud)\n",
    "##################\n",
    "\n",
    "#list_path ='../../data/shapenet/car_test.list' \n",
    "#data_dir = '../../data/shapenet/test'\n",
    "model_type = 'pcn_emd'\n",
    "checkpoint = './log/pcn_addbeta_lr'\n",
    "results_dir ='results/pcn_addbeta_lr'\n",
    "num_gt_points = 16384\n",
    "plot_freq = 1\n",
    "_save_pcd = True\n",
    "lmdb_valid='../../../data/shapenet-car/valid.lmdb'\n",
    "num_input_points=3000\n",
    "\n",
    "def test():\n",
    "    \"\"\"Evaluate the trained completion model on the validation LMDB.\n",
    "\n",
    "    Restores the latest checkpoint under ``checkpoint``, completes every\n",
    "    validation sample and scores each completion against its ground truth with\n",
    "    Chamfer distance (CD) and Earth mover's distance (EMD). Per-sample scores go\n",
    "    to ``<results_dir>/results.csv``; every ``plot_freq``-th sample is plotted\n",
    "    under ``<results_dir>/plots`` and, when ``_save_pcd`` is set, each completion\n",
    "    is saved under ``<results_dir>/pcds``. Averages are printed at the end.\n",
    "\n",
    "    All settings are read from the module-level configuration variables.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``tf.train.latest_checkpoint`` finds no checkpoint.\n",
    "    \"\"\"\n",
    "    # Build in a fresh graph so re-running this notebook cell does not collide\n",
    "    # with the variables a previous run left in the default graph.\n",
    "    with tf.Graph().as_default():\n",
    "        inputs = tf.placeholder(tf.float32, (1, None, 3))\n",
    "        my_inputs = tf.placeholder(tf.float32, (1, None, 3))\n",
    "        npts = tf.placeholder(tf.int32, (1,))\n",
    "        gt = tf.placeholder(tf.float32, (1, num_gt_points, 3))\n",
    "        model_module = importlib.import_module('.%s' % model_type, 'models')\n",
    "        model = model_module.Model(inputs, my_inputs, npts, gt, tf.constant(1.0), tf.constant(1.0))\n",
    "\n",
    "        # Metrics are evaluated on a separate placeholder fed with the model output.\n",
    "        output = tf.placeholder(tf.float32, (1, num_gt_points, 3))\n",
    "        cd_op = chamfer(output, gt)\n",
    "        emd_op = earth_mover(output, gt)\n",
    "\n",
    "        # Batch size 1, evaluation mode.\n",
    "        df_valid, num_valid = lmdb_dataflow(\n",
    "            lmdb_valid, 1, num_input_points, num_gt_points, is_training=False)\n",
    "        if num_valid == 0:\n",
    "            # Nothing to evaluate; also avoids dividing by zero when averaging.\n",
    "            print('No validation samples found in %s' % lmdb_valid)\n",
    "            return\n",
    "        valid_gen = df_valid.get_data()\n",
    "\n",
    "        checkpoint_path = tf.train.latest_checkpoint(checkpoint)\n",
    "        if checkpoint_path is None:\n",
    "            raise ValueError('No checkpoint found in %s' % checkpoint)\n",
    "\n",
    "        config = tf.ConfigProto()\n",
    "        config.gpu_options.allow_growth = True\n",
    "        config.allow_soft_placement = True\n",
    "\n",
    "        os.makedirs(results_dir, exist_ok=True)\n",
    "        plot_dir = os.path.join(results_dir, 'plots')\n",
    "        pcd_dir = os.path.join(results_dir, 'pcds')\n",
    "\n",
    "        total_time = 0\n",
    "        total_cd = 0\n",
    "        total_emd = 0\n",
    "        # Context managers close the session and the CSV file even if a sample fails.\n",
    "        with tf.Session(config=config) as sess:\n",
    "            saver = tf.train.Saver()\n",
    "            saver.restore(sess, checkpoint_path)\n",
    "\n",
    "            # newline='' is how the csv module expects files it writes to be opened.\n",
    "            with open(os.path.join(results_dir, 'results.csv'), 'w', newline='') as csv_file:\n",
    "                writer = csv.writer(csv_file)\n",
    "                writer.writerow(['id', 'cd', 'emd'])  # one row of scores per sample\n",
    "\n",
    "                for i in range(num_valid):\n",
    "                    ids, iinputs, inpts, igt = next(valid_gen)\n",
    "                    # With batch size 1, use the single id rather than the repr of a\n",
    "                    # one-element batch (e.g. ['abc']) in file names and CSV rows.\n",
    "                    # NOTE(review): assumes ids is a sequence/array of per-sample ids.\n",
    "                    model_id = ids[0] if isinstance(ids, (list, tuple, np.ndarray)) else ids\n",
    "\n",
    "                    start = time.time()\n",
    "                    completion = sess.run(model.outputs,\n",
    "                                          feed_dict={inputs: iinputs, my_inputs: iinputs, npts: inpts})\n",
    "                    total_time += time.time() - start\n",
    "                    cd, emd = sess.run([cd_op, emd_op], feed_dict={output: completion, gt: igt})\n",
    "                    total_cd += cd\n",
    "                    total_emd += emd\n",
    "                    writer.writerow([model_id, cd, emd])\n",
    "\n",
    "                    if i % plot_freq == 0:\n",
    "                        os.makedirs(plot_dir, exist_ok=True)\n",
    "                        plot_path = os.path.join(plot_dir, '%s.png' % model_id)\n",
    "                        # Index 0 drops the leading batch axis (size 1 per the placeholders).\n",
    "                        plot_pcd_three_views(plot_path, [iinputs[0], completion[0], igt[0]],\n",
    "                                             ['input', 'output', 'ground truth'],\n",
    "                                             'CD %.4f  EMD %.4f' % (cd, emd),\n",
    "                                             [5, 0.5, 0.5])\n",
    "                    if _save_pcd:\n",
    "                        os.makedirs(pcd_dir, exist_ok=True)\n",
    "                        save_pcd(os.path.join(pcd_dir, '%s.pcd' % model_id), completion[0])\n",
    "\n",
    "    print('Average time: %f' % (total_time / num_valid))\n",
    "    print('Average Chamfer distance: %f' % (total_cd / num_valid))\n",
    "    print('Average Earth mover distance: %f' % (total_emd / num_valid))\n",
    "######################\n",
    "# DISABLED LEGACY CODE: the two triple-quoted strings below are bare string\n",
    "# expressions, so nothing inside them runs.\n",
    "# 1) The original list-based evaluation loop (reads partial/complete .pcd files\n",
    "#    named in list_path under data_dir; neither name is defined in this cell).\n",
    "#    If revived, note it creates pcds/<synset_id> but saves to pcds/<model_id>.pcd.\n",
    "'''    \n",
    "    with open(list_path) as file:\n",
    "        model_list = file.read().splitlines()\n",
    "    total_time = 0\n",
    "    total_cd = 0\n",
    "    total_emd = 0\n",
    "    cd_per_cat = {}\n",
    "    emd_per_cat = {}\n",
    "    for i, model_id in enumerate(model_list):\n",
    "        partial = read_pcd(os.path.join(data_dir, 'partial', '%s.pcd' % model_id))\n",
    "        complete = read_pcd(os.path.join(data_dir, 'complete', '%s.pcd' % model_id))\n",
    "        start = time.time()\n",
    "        completion = sess.run(model.outputs, feed_dict={inputs: [partial],my_inputs:[partial], npts: [partial.shape[0]]})\n",
    "        total_time += time.time() - start\n",
    "        cd, emd = sess.run([cd_op, emd_op], feed_dict={output: completion, gt: [complete]})\n",
    "        total_cd += cd\n",
    "        total_emd += emd\n",
    "        writer.writerow([model_id, cd, emd]) #항목별 cd,emd 써줌\n",
    "\n",
    "        # 카테고리별 cd,emd 얻음\n",
    "        synset_id, model_id = model_id.split('/')\n",
    "        if not cd_per_cat.get(synset_id):\n",
    "            cd_per_cat[synset_id] = []\n",
    "        if not emd_per_cat.get(synset_id):\n",
    "            emd_per_cat[synset_id] = []\n",
    "        cd_per_cat[synset_id].append(cd)\n",
    "        emd_per_cat[synset_id].append(emd)\n",
    "        \n",
    "        # 3가지 view에서 모델 input,gt,output보여줌.\n",
    "        if i % plot_freq == 0:\n",
    "            os.makedirs(os.path.join(results_dir, 'plots', synset_id), exist_ok=True)\n",
    "            plot_path = os.path.join(results_dir, 'plots', synset_id, '%s.png' % model_id)\n",
    "            plot_pcd_three_views(plot_path, [partial, completion[0], complete],\n",
    "                                 ['input', 'output', 'ground truth'],\n",
    "                                 'CD %.4f  EMD %.4f' % (cd, emd),\n",
    "                                 [5, 0.5, 0.5])\n",
    "        if _save_pcd:\n",
    "            os.makedirs(os.path.join(results_dir, 'pcds', synset_id), exist_ok=True)\n",
    "            save_pcd(os.path.join(results_dir, 'pcds', '%s.pcd' % model_id), completion[0])\n",
    "    csv_file.close()\n",
    "    sess.close()\n",
    "\n",
    "    print('Average time: %f' % (total_time / len(model_list)))\n",
    "    print('Average Chamfer distance: %f' % (total_cd / len(model_list)))\n",
    "    print('Average Earth mover distance: %f' % (total_emd / len(model_list)))\n",
    "    print('Chamfer distance per category')\n",
    "    for synset_id in cd_per_cat.keys():\n",
    "        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))\n",
    "    print('Earth mover distance per category')\n",
    "    for synset_id in emd_per_cat.keys():\n",
    "        print(synset_id, '%f' % np.mean(emd_per_cat[synset_id]))\n",
    "        '''\n",
    "# 2) The original command-line entry point (disabled). It calls test(args), but\n",
    "#    test() in this cell takes no arguments and reads module-level settings instead.\n",
    "'''\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--list_path', default='data/shapenet/test.list')\n",
    "    parser.add_argument('--data_dir', default='data/shapenet/test')\n",
    "    parser.add_argument('--model_type', default='pcn_emd')\n",
    "    parser.add_argument('--checkpoint', default='data/trained_models/pcn_emd')\n",
    "    parser.add_argument('--results_dir', default='results/shapenet_pcn_emd')\n",
    "    parser.add_argument('--num_gt_points', type=int, default=16384)\n",
    "    parser.add_argument('--plot_freq', type=int, default=100)\n",
    "    parser.add_argument('--save_pcd', action='store_true')\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    test(args)\n",
    "'''\n",
    "# Run the LMDB-based evaluation when this cell is executed.\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}