{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "valid.lmdb를 대상으로 \n",
    "1. cd, emd cost function 값 확인\n",
    "2. 각 표본에 대한 결과 출력\n",
    "3. pcd값 저장해둬서 어떤 결과인지 직접 확인하자 - (이게 발표자료로서 의미가 있을것 같음)"
   ]
  },
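  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of steps 1 and 2: iterate valid.lmdb with `lmdb_dataflow` and report per-sample and mean CD/EMD. This is only an outline under assumptions: it reuses the session, model, placeholders and the `cd_op`/`emd_op` tensors that the `test()` cell further down builds, so it can only be called after that cell has run, and the helper name `eval_valid_lmdb` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: evaluate CD/EMD over a valid.lmdb dataflow with batch size 1.\n",
    "# Assumes lmdb_dataflow and np are imported and that sess, model, the placeholders\n",
    "# and cd_op/emd_op already exist (they are created in the test() cell below);\n",
    "# eval_valid_lmdb is a hypothetical helper name, not part of the original PCN code.\n",
    "def eval_valid_lmdb(sess, model, inputs_ph, my_inputs_ph, npts_ph, gt_ph,\n",
    "                    output_ph, cd_op, emd_op, lmdb_path,\n",
    "                    num_input_points, num_gt_points):\n",
    "    df, num_samples = lmdb_dataflow(\n",
    "        lmdb_path, 1, num_input_points, num_gt_points, is_training=False)\n",
    "    gen = df.get_data()\n",
    "    cds, emds = [], []\n",
    "    for _ in range(num_samples):\n",
    "        ids, partial, npts, complete = next(gen)\n",
    "        # run the completion network on the partial input\n",
    "        completion = sess.run(model.outputs, feed_dict={\n",
    "            inputs_ph: partial, my_inputs_ph: partial, npts_ph: npts})\n",
    "        # step 1: CD / EMD against the ground truth from the lmdb\n",
    "        cd, emd = sess.run([cd_op, emd_op],\n",
    "                           feed_dict={output_ph: completion, gt_ph: complete})\n",
    "        cds.append(cd)\n",
    "        emds.append(emd)\n",
    "        # step 2: per-sample result\n",
    "        print('%s  CD %.6f  EMD %.6f' % (ids[0], cd, emd))\n",
    "    print('mean CD %.6f  mean EMD %.6f' % (np.mean(cds), np.mean(emds)))\n",
    "    return cds, emds"
   ]
  },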
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "            num_eval_steps = num_valid // args.batch_size\n",
    "            total_loss = 0\n",
    "            total_time = 0\n",
    "            sess.run(tf.local_variables_initializer())\n",
    "            for i in range(num_eval_steps):\n",
    "                start = time.time()\n",
    "                ids, inputs, npts, gt = next(valid_gen)\n",
    "                feed_dict = {inputs_pl: inputs,my_inputs_pl:my_inputs, npts_pl: npts, gt_pl: gt, is_training_pl: False}\n",
    "                loss, _ = sess.run([model.loss, model.update], feed_dict=feed_dict)\n",
    "                total_loss += loss\n",
    "                total_time += time.time() - start\n",
    "            summary = sess.run(valid_summary, feed_dict={is_training_pl: False})\n",
    "            writer.add_summary(summary, step)\n",
    "            print(colored('epoch %d  step %d  loss %.8f - time per batch %.4f' %\n",
    "                          (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps),\n",
    "                          'grey', 'on_green'))\n",
    "            total_time = 0\n",
    "            if step % args.steps_per_visu == 0:\n",
    "                all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)\n",
    "                for i in range(0, args.batch_size, args.visu_freq):\n",
    "                    plot_path = os.path.join(args.log_dir, 'plots',\n",
    "                                            'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))\n",
    "                    pcds = [x[i] for x in all_pcds]\n",
    "                    plot_pcd_three_views(plot_path, pcds, model.visualize_titles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tf/tensorflow-tutorials/pcn_modify/pcn/models/pcn_emd.py:21: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /tf/tensorflow-tutorials/pcn_modify/pcn/models/pcn_emd.py:21: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/contrib/layers/python/layers/layers.py:1057: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:From /tf/tensorflow-tutorials/pcn_modify/pcn/tf_util.py:71: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.\n",
      "\n",
      "WARNING:tensorflow:From /tf/tensorflow-tutorials/pcn_modify/pcn/tf_util.py:75: The name tf.metrics.mean is deprecated. Please use tf.compat.v1.metrics.mean instead.\n",
      "\n",
      "INFO:tensorflow:Restoring parameters from ./log/pcn_emd_car_modify/model-26000\n",
      "Average time: 0.049774\n",
      "Average Chamfer distance: 0.009361\n",
      "Average Earth mover distance: 0.051862\n",
      "Chamfer distance per category\n",
      "02958343 0.009361\n",
      "Earth mover distance per category\n",
      "02958343 0.051862\n"
     ]
    }
   ],
   "source": [
    "# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018\n",
    "\n",
    "import argparse\n",
    "import csv\n",
    "import importlib\n",
    "import models\n",
    "import numpy as np\n",
    "import os\n",
    "import tensorflow as tf\n",
    "import time\n",
    "\n",
    "from tf_util import chamfer, earth_mover\n",
    "from visu_util import plot_pcd_three_views\n",
    "from data_util import lmdb_dataflow, get_queued_data, resample_pcd\n",
    "\n",
    "##################\n",
    "#from io_util import read_pcd, save_pcd 이부분 그냥 포함시켜버렸다...ㅎㅎ\n",
    "import numpy as np\n",
    "#from open3d import *\n",
    "import open3d as o3d\n",
    "\n",
    "\n",
    "def read_pcd(filename):\n",
    "    pcd = o3d.io.read_point_cloud(filename)\n",
    "    return np.array(pcd.points)\n",
    "\n",
    "\n",
    "def save_pcd(filename, points):\n",
    "    pcd = o3d.geometry.PointCloud()\n",
    "    pcd.points = o3d.utility.Vector3dVector(points)\n",
    "    o3d.io.write_point_cloud(filename, pcd)\n",
    "##################\n",
    "\n",
    "list_path ='../../data/shapenet/car_test.list' \n",
    "data_dir = '../../data/shapenet/test'\n",
    "model_type = 'pcn_emd'\n",
    "checkpoint = './log/pcn_emd_car_modify'\n",
    "results_dir ='results/shapenet_pcn_emd_car_modify'\n",
    "num_gt_points = 16384\n",
    "plot_freq = 1\n",
    "_save_pcd = True\n",
    "lmdb_valid = ''\n",
    "num_input_points=3000\n",
    "\n",
    "def test():\n",
    "    inputs = tf.placeholder(tf.float32, (1, None, 3))\n",
    "    my_inputs = tf.placeholder(tf.float32, (1, None, 3))\n",
    "    npts = tf.placeholder(tf.int32, (1,))\n",
    "    gt = tf.placeholder(tf.float32, (1, num_gt_points, 3))\n",
    "    model_module = importlib.import_module('.%s' % model_type, 'models')\n",
    "    model = model_module.Model(inputs,my_inputs, npts, gt, tf.constant(1.0))\n",
    "\n",
    "    output = tf.placeholder(tf.float32, (1, num_gt_points, 3))\n",
    "    cd_op = chamfer(output, gt)\n",
    "    emd_op = earth_mover(output, gt)\n",
    "\n",
    "    #####\n",
    "    df_valid, num_valid = lmdb_dataflow(\n",
    "        lmdb_valid, 1, num_input_points, num_gt_points, is_training=False)\n",
    "    valid_gen = df_valid.get_data()\n",
    "    #####\n",
    "    \n",
    "    config = tf.ConfigProto()\n",
    "    config.gpu_options.allow_growth = True\n",
    "    config.allow_soft_placement = True\n",
    "    sess = tf.Session(config=config)\n",
    "\n",
    "    saver = tf.train.Saver()\n",
    "    saver.restore(sess, tf.train.latest_checkpoint(checkpoint))\n",
    "\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "    csv_file = open(os.path.join(results_dir, 'results.csv'), 'w') # 각 항목별로 cd, emd 구해줌.\n",
    "    writer = csv.writer(csv_file)\n",
    "    writer.writerow(['id', 'cd', 'emd'])\n",
    "\n",
    "    with open(list_path) as file:\n",
    "        model_list = file.read().splitlines()\n",
    "    total_time = 0\n",
    "    total_cd = 0\n",
    "    total_emd = 0\n",
    "    cd_per_cat = {}\n",
    "    emd_per_cat = {}\n",
    "    for i, model_id in enumerate(model_list):\n",
    "        partial = read_pcd(os.path.join(data_dir, 'partial', '%s.pcd' % model_id))\n",
    "        complete = read_pcd(os.path.join(data_dir, 'complete', '%s.pcd' % model_id))\n",
    "        start = time.time()\n",
    "        completion = sess.run(model.outputs, feed_dict={inputs: [partial],my_inputs:[partial], npts: [partial.shape[0]]})\n",
    "        total_time += time.time() - start\n",
    "        cd, emd = sess.run([cd_op, emd_op], feed_dict={output: completion, gt: [complete]})\n",
    "        total_cd += cd\n",
    "        total_emd += emd\n",
    "        writer.writerow([model_id, cd, emd]) #항목별 cd,emd 써줌\n",
    "\n",
    "        # 카테고리별 cd,emd 얻음\n",
    "        synset_id, model_id = model_id.split('/')\n",
    "        if not cd_per_cat.get(synset_id):\n",
    "            cd_per_cat[synset_id] = []\n",
    "        if not emd_per_cat.get(synset_id):\n",
    "            emd_per_cat[synset_id] = []\n",
    "        cd_per_cat[synset_id].append(cd)\n",
    "        emd_per_cat[synset_id].append(emd)\n",
    "        \n",
    "        # 3가지 view에서 모델 input,gt,output보여줌.\n",
    "        if i % plot_freq == 0:\n",
    "            os.makedirs(os.path.join(results_dir, 'plots', synset_id), exist_ok=True)\n",
    "            plot_path = os.path.join(results_dir, 'plots', synset_id, '%s.png' % model_id)\n",
    "            plot_pcd_three_views(plot_path, [partial, completion[0], complete],\n",
    "                                 ['input', 'output', 'ground truth'],\n",
    "                                 'CD %.4f  EMD %.4f' % (cd, emd),\n",
    "                                 [5, 0.5, 0.5])\n",
    "        if _save_pcd:\n",
    "            os.makedirs(os.path.join(results_dir, 'pcds', synset_id), exist_ok=True)\n",
    "            save_pcd(os.path.join(results_dir, 'pcds', '%s.pcd' % model_id), completion[0])\n",
    "    csv_file.close()\n",
    "    sess.close()\n",
    "\n",
    "    print('Average time: %f' % (total_time / len(model_list)))\n",
    "    print('Average Chamfer distance: %f' % (total_cd / len(model_list)))\n",
    "    print('Average Earth mover distance: %f' % (total_emd / len(model_list)))\n",
    "    print('Chamfer distance per category')\n",
    "    for synset_id in cd_per_cat.keys():\n",
    "        print(synset_id, '%f' % np.mean(cd_per_cat[synset_id]))\n",
    "    print('Earth mover distance per category')\n",
    "    for synset_id in emd_per_cat.keys():\n",
    "        print(synset_id, '%f' % np.mean(emd_per_cat[synset_id]))\n",
    "'''\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--list_path', default='data/shapenet/test.list')\n",
    "    parser.add_argument('--data_dir', default='data/shapenet/test')\n",
    "    parser.add_argument('--model_type', default='pcn_emd')\n",
    "    parser.add_argument('--checkpoint', default='data/trained_models/pcn_emd')\n",
    "    parser.add_argument('--results_dir', default='results/shapenet_pcn_emd')\n",
    "    parser.add_argument('--num_gt_points', type=int, default=16384)\n",
    "    parser.add_argument('--plot_freq', type=int, default=100)\n",
    "    parser.add_argument('--save_pcd', action='store_true')\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    test(args)\n",
    "'''\n",
    "test()"
   ]
  },
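  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small follow-up sketch: re-read the `results.csv` written by `test()` and recompute the per-category means as a sanity check on the numbers printed above. It assumes `test()` has already run, so that `results.csv` exists under `results_dir` with the `id, cd, emd` columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: recompute per-category mean CD/EMD from results.csv.\n",
    "# Assumes test() above has been run so the CSV exists; uses only csv, np and os,\n",
    "# which are already imported in the cell above.\n",
    "cd_by_cat, emd_by_cat = {}, {}\n",
    "with open(os.path.join(results_dir, 'results.csv')) as f:\n",
    "    for row in csv.DictReader(f):\n",
    "        synset_id = row['id'].split('/')[0]  # ids are stored as synset_id/model_id\n",
    "        cd_by_cat.setdefault(synset_id, []).append(float(row['cd']))\n",
    "        emd_by_cat.setdefault(synset_id, []).append(float(row['emd']))\n",
    "for synset_id in cd_by_cat:\n",
    "    print(synset_id, 'CD %.6f  EMD %.6f' % (\n",
    "        np.mean(cd_by_cat[synset_id]), np.mean(emd_by_cat[synset_id])))"
   ]
  },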
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "read_point_cloud(): incompatible function arguments. The following argument types are supported:\n    1. (filename: str, format: str = 'auto', remove_nan_points: bool = True, remove_infinite_points: bool = True, print_progress: bool = False) -> open3d.open3d.geometry.PointCloud\n\nInvoked with: ",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-77d1fb1ef881>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mopen3d\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_point_cloud\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m: read_point_cloud(): incompatible function arguments. The following argument types are supported:\n    1. (filename: str, format: str = 'auto', remove_nan_points: bool = True, remove_infinite_points: bool = True, print_progress: bool = False) -> open3d.open3d.geometry.PointCloud\n\nInvoked with: "
     ]
    }
   ],
   "source": [
    "from open3d import *\n",
    "io.read_point_cloud()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}