Monday, November 25, 2019

Jetson Nano - Human Pose Estimation using TensorFlow (Boost the FPS using TensorRT)

I used a Jetson Nano with the official Ubuntu 18.04 image and the root account. Please read my previous article first.

Prerequisites

Before you build "ildoonet/tf-pose-estimation", you must install these packages first. See the URLs.

Performance Comparison


I used a lightweight pose estimation TensorFlow framework (https://github.com/ildoonet/tf-pose-estimation) in my previous post. The framework provides four network models, and you can choose the one that matches the performance of your machine. The network models are stored in (or will be downloaded to) the models/graph directory.

The following table compares performance with and without TensorRT.

I executed these commands to get the results above.


python3 run_webcam.py --model=mobilenet_thin --resize=368x368
python3 run_webcam.py --model=mobilenet_large --resize=368x368
python3 run_webcam.py --model=mobilenet_v2_large --resize=368x368
python3 run_webcam.py --model=mobilenet_v2_small --resize=368x368
python3 run_webcam.py --model=mobilenet_thin --resize=368x368  --tensorrt=True
python3 run_webcam.py --model=mobilenet_large --resize=368x368  --tensorrt=True
python3 run_webcam.py --model=mobilenet_v2_large --resize=368x368  --tensorrt=True
python3 run_webcam.py --model=mobilenet_v2_small --resize=368x368  --tensorrt=True

python3 run_webcam.py --model=mobilenet_thin --resize=160x160
python3 run_webcam.py --model=mobilenet_large --resize=160x160
python3 run_webcam.py --model=mobilenet_v2_large --resize=160x160
python3 run_webcam.py --model=mobilenet_v2_small --resize=160x160
python3 run_webcam.py --model=mobilenet_thin --resize=160x160  --tensorrt=True
python3 run_webcam.py --model=mobilenet_large --resize=160x160  --tensorrt=True
python3 run_webcam.py --model=mobilenet_v2_large --resize=160x160  --tensorrt=True
python3 run_webcam.py --model=mobilenet_v2_small --resize=160x160  --tensorrt=True


Under the hood


Network model load time

The larger the network model, the longer it takes TensorFlow to load it, and with TensorRT the load time is even longer. However, once the network model has been loaded, TensorRT saves time on every inference image. In particular, a large and complex model such as cmu shows a big performance difference.
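If you want to measure the load time yourself, here is a minimal sketch, assuming the TfPoseEstimator and get_graph_path helpers are used the same way as in run_webcam.py (the trt_bool parameter may differ depending on the repository version):

import time

from tf_pose.estimator import TfPoseEstimator
from tf_pose.networks import get_graph_path

start = time.time()
# trt_bool=True takes the TensorRT conversion path inside estimator.py
e = TfPoseEstimator(get_graph_path('mobilenet_thin'),
                    target_size=(368, 368), trt_bool=True)
print('Model load took %.1f seconds' % (time.time() - start))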

TensorRT network model processing

The source code at https://github.com/ildoonet/tf-pose-estimation generates a new TensorRT model every time it runs with TensorRT enabled. If the TensorFlow model is converted to TensorRT in advance, this unnecessary step can be skipped. The Python code that performs the TensorRT conversion at runtime is in the tf_pose/estimator.py file.
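The relevant section looks roughly like the following (a reconstruction based on the repository, so the exact parameter values may differ):

# Rough sketch of the runtime conversion path in tf_pose/estimator.py
# (see the repository for the exact code).
if trt_bool is True:
    output_nodes = ["Openpose/concat_stage7"]
    graph_def = trt.create_inference_graph(
        graph_def,                        # frozen TensorFlow graph
        output_nodes,                     # output tensor names
        max_batch_size=1,
        max_workspace_size_bytes=1 << 20,
        precision_mode="FP16",
        minimum_segment_size=3,
        is_dynamic_op=True,
        maximum_cached_engines=int(1e3))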


<tf_pose/estimator.py>


The create_inference_graph function takes the TensorFlow model as an input parameter and creates a model for TensorRT. By creating the TensorRT network model in advance and modifying the code so that this model is read directly, the conversion step at startup can be skipped.

Be careful: models converted for TensorRT are not compatible across different versions of TensorRT. Therefore, the safest approach is to convert and use the model on the same Jetson device you are currently working on.


Convert the TensorFlow model for TensorRT

I wrote the following Python code to convert the TensorFlow model for TensorRT.


import argparse
import sys, os
import time
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
#from tf_pose import common
#import cv2


def get_frozen_graph(graph_file):
  """Read Frozen Graph file from disk."""
  with tf.gfile.FastGFile(graph_file, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  return graph_def


parser = argparse.ArgumentParser(description='tf network model conversion to tensorrt')
parser.add_argument('--model', type=str, default='mobilenet_v2_small',
                        help='cmu / mobilenet_thin / mobilenet_v2_large / mobilenet_v2_small')
args = parser.parse_args()

model_dir = 'models/graph/'
frozen_name = model_dir + args.model + '/graph_opt.pb'
frozen_graph = get_frozen_graph(frozen_name)
print('=======Frozen Name:%s=======' % frozen_name)
# convert (optimize) frozen model to TensorRT model
your_outputs = ["Openpose/concat_stage7"]

start = time.time()
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,        # frozen TensorFlow model
    outputs=your_outputs,
    is_dynamic_op=True,
    minimum_segment_size=3,
    maximum_cached_engines=int(1e3),
    max_batch_size=1,                    # specify your max batch size
    max_workspace_size_bytes=2*(10**9),  # specify the max workspace
    precision_mode="FP16")               # precision, can be "FP32" or "FP16"

elapsed = time.time() - start
print('Tensorflow model => TensorRT model takes : %f' % elapsed)

# write the TensorRT model to be used later for inference
rt_name = model_dir + args.model + '/graph_opt_rt.pb'
with tf.gfile.FastGFile(rt_name, 'wb') as f:
    f.write(trt_graph.SerializeToString())
<tf_model_2_rt.py>

Run the code to create the TensorRT model. If it succeeds, you will see output like this.


python3 tf_model_2_rt.py --model=cmu

2019-11-24 23:52:26.250140: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:733] Number of TensorRT candidate segments: 1
2019-11-24 23:52:27.198957: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudart.so.10.0
2019-11-24 23:52:28.314413: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:835] TensorRT node TRTEngineOp_0 added for segment 0 consisting of 470 nodes succeeded.
2019-11-24 23:52:28.492216: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:739] Optimization results for grappler item: tf_graph
2019-11-24 23:52:28.492379: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:741]   constant folding: Graph size after: 468 nodes (0), 478 edges (0), time = 1715.36804ms.
2019-11-24 23:52:28.492444: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:741]   layout: Graph size after: 478 nodes (10), 484 edges (6), time = 240.81ms.
2019-11-24 23:52:28.492489: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:741]   constant folding: Graph size after: 473 nodes (-5), 484 edges (0), time = 505.675ms.
2019-11-24 23:52:28.492532: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:741]   TensorRTOptimizer: Graph size after: 4 nodes (-469), 4 edges (-480), time = 2957.31ms.
Tensorflow model => TensorRT model takes : 16.847830

A new .pb file (graph_opt_rt.pb) now exists in the models/graph/cmu directory. Run this Python script for the other network models to create their TensorRT models as well (a helper sketch for doing this in one go follows the listing below).



root@spytx-desktop:/work/src/pose_estimation/tf-pose-estimation# ls -al models/graph/cmu/
total 613232
drwxr-xr-x 2 root root      4096 Nov 24 00:21 .
drwxr-xr-x 6 root root      4096 Nov  9 21:07 ..
-rw-r--r-- 1 root root       643 Nov  9 21:07 download.sh
-rw-r--r-- 1 root root 209299198 Nov 10 13:31 graph_opt.pb
-rw-r--r-- 1 root root 418623762 Nov 24 23:52 graph_opt_rt.pb
-rw-r--r-- 1 root root         0 Nov  9 21:07 __init__.py
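
If you want to convert all of the models in one go, something like this hypothetical helper could be used (the model names are assumed to match the directories under models/graph):

# convert_all_models.py (hypothetical helper): runs tf_model_2_rt.py for each model
import subprocess

MODELS = ['cmu', 'mobilenet_thin', 'mobilenet_v2_large', 'mobilenet_v2_small']

for model in MODELS:
    print('=== converting %s ===' % model)
    subprocess.run(['python3', 'tf_model_2_rt.py', '--model=%s' % model], check=True)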


Modify estimator.py to use the ready-made TensorRT models

I modified estimator.py to open the ready-made TensorRT model instead of converting the TensorFlow model every time it starts.


You can download the modified estimator.py file (only for use on the Jetson series) from my repo.
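As a rough sketch of the idea (not the exact code in my repo), the change boils down to reading the pre-converted graph_opt_rt.pb instead of calling trt.create_inference_graph() at startup:

# Sketch only: prefer the ready-made TensorRT graph when it exists.
# The actual estimator.py in my repo may differ in detail.
import os
import tensorflow as tf

def load_graph_def(graph_path, trt_bool=False):
    """Return a GraphDef, using the pre-converted TensorRT model when available."""
    if trt_bool:
        rt_path = graph_path.replace('graph_opt.pb', 'graph_opt_rt.pb')
        if os.path.exists(rt_path):
            graph_path = rt_path   # e.g. models/graph/cmu/graph_opt_rt.pb
    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def   # no trt.create_inference_graph() call is needed here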


Wrapping up

Processing complex and large network models with TensorRT can provide significant speedups. In pose estimation, the cmu model in particular shows a large performance gain when using TensorRT. I will test the above results on a TX2 soon, and I will also cover converting a generic TensorFlow network model for TensorRT in a later post.

If you want the most satisfactory human pose estimation performance on the Jetson Nano, see the following article (https://spyjetson.blogspot.com/2019/12/jetsonnano-human-pose-estimation-using.html). The NVIDIA team introduces human pose estimation using models optimized for TensorRT.




