I wrote a blog post about pose estimation in PyTorch (https://spyjetson.blogspot.com/2019/10/jetsonnano-human-pose-estimation-using_16.html) last year. The conclusion at that time was that PyTorch pose estimation with the ResNet-50 model is excellent in accuracy but hard to use on the Jetson Nano because of its low performance of about 0.2 FPS. In this article, I will look at the performance of Detectron2's pose estimation and compare it with the results from that earlier post.
Since Detectron2 is used in much the same way regardless of the model, pose estimation is quite easy to follow if you have already read the segmentation article introduced in the previous post.
Prerequisites
This article assumes that the Jetson Nano is running JetPack 4.4 DP or later, with PyTorch 1.5.0 or later and torchvision 0.6.0 or later.
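To quickly confirm that your environment matches these assumptions, you can run a short check. This snippet is just a convenience added here, not part of the original setup.

# Optional sanity check for the versions this article assumes.
import torch
import torchvision

print('PyTorch:', torch.__version__)            # expect 1.5.0 or later
print('torchvision:', torchvision.__version__)  # expect 0.6.0 or later
print('CUDA available:', torch.cuda.is_available())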
Keypoint Detection
On Detectron2's GitHub page (https://github.com/facebookresearch/detectron2), you can see the keypoint detection models supported by Detectron2 by looking in the configs directory.
In addition to the Keypoint R-CNN ResNet-50 model used in last year's article, Detectron2 also provides ResNet-101 (R101) and ResNeXt-101 (X101) variants.
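Switching between these models only means changing the YAML config path; everything else in the code stays the same. The small snippet below is my own addition (not from the original post) and simply prints where each COCO-Keypoints config lives inside the installed detectron2 package.

from detectron2 import model_zoo

# Each of these names can be passed to --model in the scripts below;
# only the COCO-Keypoints YAML path changes.
for name in ['keypoint_rcnn_R_50_FPN_1x',
             'keypoint_rcnn_R_50_FPN_3x',
             'keypoint_rcnn_R_101_FPN_3x',
             'keypoint_rcnn_X_101_32x8d_FPN_3x']:
    print(model_zoo.get_config_file('COCO-Keypoints/%s.yaml' % name))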
Let's modify the example used in the Detectron2 segmentation post and run a version adapted for keypoint detection.
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2
import numpy as np
import requests, sys, time, os
from PIL import Image, ImageDraw
import argparse

COLORS = [(0, 45, 74, 224), (85, 32, 98, 224), (93, 69, 12, 224),
          (49, 18, 55, 224), (46, 67, 18, 224), (30, 74, 93, 224)]

help = 'Base-Keypoint-RCNN-FPN'
help += ',keypoint_rcnn_R_101_FPN_3x'
help += ',keypoint_rcnn_R_50_FPN_1x'
help += ',keypoint_rcnn_R_50_FPN_3x'
help += ',keypoint_rcnn_X_101_32x8d_FPN_3x'

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='keypoint_rcnn_R_50_FPN_1x', help=help)
parser.add_argument('--file', type=str, default='')
parser.add_argument('--size', type=str, default='640X480', help='image inference size ex:320X240')
opt = parser.parse_args()
W, H = opt.size.split('X')

if opt.file == '':
    url = 'http://images.cocodataset.org/val2017/000000439715.jpg'
    img = Image.open(requests.get(url, stream=True).raw).resize((int(W), int(H)))
    im = np.asarray(img, dtype="uint8")
    height, width, channels = im.shape
    if channels == 3:
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    else:
        im = cv2.cvtColor(im, cv2.COLOR_RGBA2BGR)
else:
    im = cv2.imread(opt.file, cv2.IMREAD_COLOR)
    height, width, channels = im.shape

print('image W:%d H:%d'%(width, height))

network_model = 'COCO-Keypoints/' + opt.model + '.yaml'
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)
predictor = DefaultPredictor(cfg)

for i in range(2):
    fps_time = time.perf_counter()
    outputs = predictor(im)
    fps = 1.0 / (time.perf_counter() - fps_time)
    if i == 1:
        print('===== output =====')
        print(outputs)
        print("Net FPS: %f" % (fps))

v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.imwrite("detectron2_poseestimation_%s_result.jpg"%(opt.model), out.get_image()[:, :, ::-1])
cv2.imwrite("./source_image.jpg", im)
<poseestimation.py>
Run the code.
root@jetpack-4:/usr/local/src/detectron2_test# python3 poseestimation.py --file='../test_images/peds_0.jpg'
image W:1920 H:1080
===== output =====
{'instances': Instances(num_instances=4, image_height=1080, image_width=1920, fields=[pred_boxes: Boxes(tensor([[1036.1870,  106.6064, 1403.8110,  973.7678],
        [ 441.3257,  123.6492,  735.9746,  978.6694],
        [ 745.8071,  183.1393, 1015.6016,  935.7772],
        [1483.9121,  160.6765, 1733.1891,  879.7750]], device='cuda:0')),
scores: tensor([0.9999, 0.9999, 0.9994, 0.9994], device='cuda:0'),
pred_classes: tensor([0, 0, 0, 0], device='cuda:0'),
pred_keypoints: tensor([[[1.2279e+03, 2.2812e+02, 1.2061e+00], [1.2466e+03, 2.1231e+02, 1.5694e+00], [1.2135e+03, 2.1087e+02, 1.2504e+00], [1.2739e+03, 1.8930e+02, 7.5815e-01], [1.1949e+03, 1.9073e+02, 6.5784e-01], [1.3284e+03, 2.7702e+02, 3.0210e-01], [1.1532e+03, 2.7558e+02, 5.2108e-01], [1.3543e+03, 4.0213e+02, 3.6574e-01], [1.1087e+03, 3.6906e+02, 4.3612e-01], [1.3643e+03, 5.2437e+02, 5.2245e-01], [1.0742e+03, 4.2802e+02, 4.0177e-01], [1.2839e+03, 5.4306e+02, 1.1480e-01], [1.1762e+03, 5.3731e+02, 1.2940e-01], [1.2566e+03, 7.2426e+02, 2.7323e-01], [1.1863e+03, 7.1276e+02, 2.8741e-01], [1.2322e+03, 8.9395e+02, 2.2146e-01], [1.2006e+03, 8.7670e+02, 1.3534e-01]], [[6.2027e+02, 2.0930e+02, 9.4037e-01], [6.3608e+02, 1.9490e+02, 1.0583e+00], [6.0446e+02, 1.9202e+02, 9.3918e-01], [6.5620e+02, 2.0066e+02, 1.4123e+00], [5.7859e+02, 1.9922e+02, 1.7396e+00], [6.8926e+02, 2.9422e+02, 2.8203e-01], [5.2397e+02, 2.8990e+02, 3.2651e-01], [7.1082e+02, 4.3097e+02, 3.6540e-01], [4.8516e+02, 4.0218e+02, 3.4584e-01], [7.0795e+02, 5.5332e+02, 6.2979e-01], [4.8229e+02, 4.9142e+02, 5.2058e-01], [6.4614e+02, 5.3317e+02, 1.1674e-01], [5.4266e+02, 5.3317e+02, 1.5501e-01], [6.0877e+02, 7.0734e+02, 2.0952e-01], [5.7428e+02, 7.0878e+02, 2.1423e-01], [5.8290e+02, 9.1749e+02, 3.0073e-01], [6.2458e+02, 8.4120e+02, 1.8838e-01]], [[8.6994e+02, 2.5006e+02, 2.4667e+00], [8.8573e+02, 2.3710e+02, 2.5618e+00], [8.5559e+02, 2.3710e+02, 2.4670e+00], [9.0582e+02, 2.5150e+02, 1.2572e+00], [8.3550e+02, 2.5150e+02, 1.3055e+00], [9.5318e+02, 3.3208e+02, 2.7176e-01], [7.9388e+02, 3.4360e+02, 2.9622e-01], [9.8475e+02, 4.4001e+02, 3.2787e-01], [7.7092e+02, 4.5728e+02, 2.9786e-01], [9.7327e+02, 5.4651e+02, 5.8070e-01], [7.6518e+02, 5.6665e+02, 3.4241e-01], [9.2878e+02, 5.6090e+02, 1.0988e-01], [8.2402e+02, 5.6090e+02, 1.0114e-01], [9.2591e+02, 7.0768e+02, 1.8532e-01], [8.2832e+02, 7.2495e+02, 2.0711e-01], [9.2878e+02, 8.1274e+02, 2.8305e-01], [8.3406e+02, 8.9764e+02, 2.3432e-01]], [[1.5863e+03, 2.2468e+02, 2.2018e-01], [1.5863e+03, 2.1317e+02, 2.2034e-01], [1.6580e+03, 2.2324e+02, 2.6580e-01], [1.5921e+03, 2.2180e+02, 7.7730e-01], [1.6551e+03, 2.2611e+02, 9.5655e-01], [1.5505e+03, 2.9946e+02, 2.8758e-01], [1.6938e+03, 2.9659e+02, 2.5435e-01], [1.5161e+03, 4.0877e+02, 3.5267e-01], [1.7024e+03, 3.9438e+02, 3.7222e-01], [1.5004e+03, 4.8499e+02, 4.2061e-01], [1.6494e+03, 4.0589e+02, 4.2102e-02], [1.5777e+03, 5.0369e+02, 6.9173e-02], [1.6608e+03, 5.1951e+02, 8.7721e-02], [1.5835e+03, 6.7483e+02, 2.1163e-01], [1.6351e+03, 6.6189e+02, 2.8551e-01], [1.6035e+03, 8.4310e+02, 1.3873e-01], [1.6279e+03, 8.1146e+02, 1.0684e-01]]], device='cuda:0')])} Net FPS: 0.194772 Segmentation fault (core dumped)
And you can get the result image like this.
<detectron2_poseestimation_keypoint_rcnn_R_50_FPN_1x_result.jpg>
This confirms that keypoint detection works in the same way as segmentation. Now let's handle the output ourselves instead of relying on Detectron2's Visualizer.
Under the Hood
Let's take a look at the values printed by the print() call above.
- num_instances : Number of persons detected
- image_height : Inference image height
- image_width : Inference image width
- pred_boxes : Bounding-box coordinates of each detected person
- scores : Confidence that each detection is a person
- pred_classes : Class of each predicted object. Keypoint detection has only the single 'person' class, so pred_classes contains only zeros.
- pred_keypoints : Keypoint information, a tensor of shape (N, num_keypoint, 3). Each row in the last dimension is (x, y, score), and the scores are larger than 0.
The official description of these values can be found at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format.
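For reference, here is a minimal sketch of how these fields can be read directly from the outputs dictionary returned by predictor(im) in the script above. This snippet is my own addition, not part of the original example.

# Assumes `outputs` comes from `predictor(im)` as in poseestimation.py above.
instances = outputs["instances"].to("cpu")

print(len(instances))            # num_instances
print(instances.image_size)      # (image_height, image_width)

boxes = instances.pred_boxes.tensor.numpy()    # (N, 4) boxes as x1, y1, x2, y2
scores = instances.scores.numpy()              # (N,) confidence per detection
classes = instances.pred_classes.numpy()       # (N,) all zeros: the 'person' class
keypoints = instances.pred_keypoints.numpy()   # (N, 17, 3) rows of (x, y, score)

for box, score, kps in zip(boxes, scores, keypoints):
    nose_x, nose_y, nose_score = kps[0]        # index 0 is 'nose' in COCO keypoint order
    print(box, score, nose_x, nose_y)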
The following Python script checks the metadata of the keypoint detection dataset.
from detectron2 import model_zoo
from detectron2.data import MetadataCatalog
from detectron2.config import get_cfg

network_model = 'COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml'
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)

cls = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
print('NetworkModel[%s] category:%d'%(network_model, len(cls)))
print(cls)

kname = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).keypoint_names
print('keypoint_names')
print(kname)

krule = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).keypoint_connection_rules
print('keypoint_connection_rules')
print(krule)
<metadata_keypoint.py>
Run the code.
root@jetpack-4:/usr/local/src/detectron2_test# python3 metadata.py
NetworkModel[COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml] category:1
['person']
keypoint_names
('nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle')
keypoint_connection_rules
[('left_ear', 'left_eye', (102, 204, 255)), ('right_ear', 'right_eye', (51, 153, 255)), ('left_eye', 'nose', (102, 0, 204)), ('nose', 'right_eye', (51, 102, 255)), ('left_shoulder', 'right_shoulder', (255, 128, 0)), ('left_shoulder', 'left_elbow', (153, 255, 204)), ('right_shoulder', 'right_elbow', (128, 229, 255)), ('left_elbow', 'left_wrist', (153, 255, 153)), ('right_elbow', 'right_wrist', (102, 255, 224)), ('left_hip', 'right_hip', (255, 102, 0)), ('left_hip', 'left_knee', (255, 255, 77)), ('right_hip', 'right_knee', (153, 255, 204)), ('left_knee', 'left_ankle', (191, 255, 128)), ('right_knee', 'right_ankle', (255, 195, 77))]
As you can see, the keypoint detection model has only one class category ('person'), 17 keypoints, and 14 keypoint connection rules.
Now let's draw the keypoint names and positions directly onto the image without using Detectron2's Visualizer.
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2
import numpy as np
import requests, sys, time, os
from PIL import Image, ImageDraw
import argparse

COLORS = [(0, 45, 74), (85, 32, 98), (93, 69, 12), (49, 18, 55), (46, 67, 18), (30, 74, 93),
          (218, 0, 0), (0, 218, 0), (0, 0, 218), (218, 218, 0), (0, 218, 218), (218, 0, 218),
          (128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0), (0, 128, 128)]

help = 'Base-Keypoint-RCNN-FPN'
help += ',keypoint_rcnn_R_101_FPN_3x'
help += ',keypoint_rcnn_R_50_FPN_1x'
help += ',keypoint_rcnn_R_50_FPN_3x'
help += ',keypoint_rcnn_X_101_32x8d_FPN_3x'

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='keypoint_rcnn_R_50_FPN_1x', help=help)
parser.add_argument('--file', type=str, default='')
parser.add_argument('--size', type=str, default='640X480', help='image inference size ex:320X240')
opt = parser.parse_args()
W, H = opt.size.split('X')

if opt.file == '':
    url = 'http://images.cocodataset.org/val2017/000000439715.jpg'
    img = Image.open(requests.get(url, stream=True).raw).resize((int(W), int(H)))
    im = np.asarray(img, dtype="uint8")
    height, width, channels = im.shape
    if channels == 3:
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    else:
        im = cv2.cvtColor(im, cv2.COLOR_RGBA2BGR)
else:
    im = cv2.imread(opt.file, cv2.IMREAD_COLOR)
    height, width, channels = im.shape

print('image W:%d H:%d'%(width, height))

network_model = 'COCO-Keypoints/' + opt.model + '.yaml'
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)
predictor = DefaultPredictor(cfg)

kname = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).keypoint_names
im2 = im.copy()
font = cv2.FONT_HERSHEY_SIMPLEX  # normal size sans-serif font

outputs = predictor(im)
kpersons = outputs["instances"].pred_keypoints
#for kpoints in kperson[0]:
for kperson in kpersons:
    print('==== person ====')
    for i in range(0, len(kperson)):
        kpoints = kperson[i].cpu()
        x = kpoints[0]
        y = kpoints[1]
        print('%-20s position (%f, %f)'%(kname[i], x, y))
        cv2.circle(im2, (int(x), int(y)), 10, color=COLORS[i], thickness=4)
        cv2.putText(im2, kname[i], (int(x) - 20, int(y) - 10), font, fontScale=1, color=COLORS[i], thickness=2)

cv2.imwrite("detectron2_PoseEstimation_%s_result.jpg"%(opt.model), im2)
<poseestimation2.py>
Run the code.
root@jetpack-4:/usr/local/src/detectron2_test# python3 poseestimation2.py --file='../test_images/peds_0.jpg' image W:1920 H:1080 ==== person ==== nose position (1227.897217, 228.124039) left_eye position (1246.565674, 212.305176) right_eye position (1213.536987, 210.867111) left_ear position (1273.850220, 189.295929) right_ear position (1194.868530, 190.734009) left_shoulder position (1328.419312, 277.018738) right_shoulder position (1153.223633, 275.580627) left_elbow position (1354.267944, 402.131531) right_elbow position (1108.706665, 369.055725) left_wrist position (1364.320190, 524.368225) right_wrist position (1074.241943, 428.016998) left_hip position (1283.902466, 543.063293) right_hip position (1176.200073, 537.310913) left_knee position (1256.617798, 724.261169) right_knee position (1186.252319, 712.756531) left_ankle position (1232.205322, 893.954407) right_ankle position (1200.612671, 876.697510) ==== person ==== nose position (620.270996, 209.295181) left_eye position (636.081421, 194.900925) right_eye position (604.460571, 192.022064) left_ear position (656.203857, 200.658615) right_ear position (578.588989, 199.219177) left_shoulder position (689.261963, 294.221436) right_shoulder position (523.971130, 289.903137) left_elbow position (710.821655, 430.967041) right_elbow position (485.163757, 402.178497) left_wrist position (707.947021, 553.318420) right_wrist position (482.289093, 491.423035) left_hip position (646.142639, 533.166443) right_hip position (542.656189, 533.166443) left_knee position (608.772522, 707.337158) right_knee position (574.277039, 708.776611) left_ankle position (582.900879, 917.493591) right_ankle position (624.582947, 841.204041) ==== person ==== nose position (869.941284, 250.056412) left_eye position (885.727112, 237.104706) right_eye position (855.590515, 237.104706) left_ear position (905.818115, 251.495499) right_ear position (835.499390, 251.495499) left_shoulder position (953.175659, 332.083893) right_shoulder position (793.882202, 343.596497) left_elbow position (984.747375, 440.014771) right_elbow position (770.920898, 457.283691) left_wrist position (973.266785, 546.506531) right_wrist position (765.180664, 566.653625) left_hip position (928.779419, 560.897278) right_hip position (824.018799, 560.897278) left_knee position (925.909241, 707.683228) right_knee position (828.323975, 724.952209) left_ankle position (928.779419, 812.735962) right_ankle position (834.064331, 897.641602) ==== person ==== nose position (1586.344971, 224.676224) left_eye position (1586.344971, 213.170654) right_eye position (1657.976074, 223.238022) left_ear position (1592.075439, 221.799850) right_ear position (1655.110840, 226.114426) left_shoulder position (1550.529297, 299.462494) right_shoulder position (1693.791870, 296.586090) left_elbow position (1516.146240, 408.765472) right_elbow position (1702.387573, 394.383484) left_wrist position (1500.387329, 484.989899) right_wrist position (1649.380371, 405.889069) left_hip position (1577.749146, 503.686462) right_hip position (1660.841309, 519.506653) left_knee position (1583.479614, 674.831909) right_knee position (1635.054077, 661.888123) left_ankle position (1603.536377, 843.101074) right_ankle position (1627.890991, 811.460571) Segmentation fault (core dumped)
And you can get the result image like this.
<detectron2_PoseEstimation_keypoint_rcnn_R_50_FPN_1x_result.jpg>
Now you can freely use the keypoint location and name without the help of Detectron2's Visualizer.
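As one more example, the keypoint_connection_rules we printed earlier can be used to draw the skeleton lines as well. The sketch below is my own addition, not part of the original post; it assumes the variables cfg, kname, outputs, and im2 from poseestimation2.py above are still in scope, uses an arbitrary score threshold of 0.05, and reverses the rule colors because they are RGB while OpenCV expects BGR.

from detectron2.data import MetadataCatalog
import cv2

# Connection rules are (keypoint_name_a, keypoint_name_b, (r, g, b)) tuples.
rules = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).keypoint_connection_rules
kp_index = {name: i for i, name in enumerate(kname)}
keypoints = outputs["instances"].pred_keypoints.cpu().numpy()  # (N, 17, 3)

for person in keypoints:
    for name_a, name_b, color in rules:
        xa, ya, sa = person[kp_index[name_a]]
        xb, yb, sb = person[kp_index[name_b]]
        if sa > 0.05 and sb > 0.05:  # draw a limb only if both endpoint scores look plausible
            bgr = (int(color[2]), int(color[1]), int(color[0]))  # RGB -> BGR for OpenCV
            cv2.line(im2, (int(xa), int(ya)), (int(xb), int(yb)), bgr, thickness=3)

cv2.imwrite("detectron2_PoseEstimation_skeleton_result.jpg", im2)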
Wrapping up
So far, I have looked at the pose estimation models supported by Detectron2 on the Jetson Nano. As with the PyTorch pose estimation covered in the earlier post, the accuracy is excellent, but the models are still too slow to use on the Jetson Nano. I always use 10 FPS as my criterion for practical usability.
The current speed of roughly 0.2 FPS is far below that, so it is not practical on the Jetson Nano. Perhaps it can be used on the Xavier NX or more powerful hardware.
Detectron2 has been reorganized to be fairly easy to use, and it will continue to evolve at Facebook AI Research. They have also promised to release R-CNN models with improved processing speed based on lightweight backbones such as FBNet, ShuffleNet, and MobileNet. Please check Detectron2's official GitHub repository, https://github.com/facebookresearch/detectron2, from time to time.