Saturday, June 20, 2020

Jetson Nano - Detectron2 Segmentation Models

Prerequisites

In a previous blog post, you learned how to install Detectron2 on the Jetson Nano.

This time, I will look at the types and performance of the object detection and segmentation models supported by Detectron2.
You can see the various models Detectron2 supports by browsing the configs directory of its GitHub repository: https://github.com/facebookresearch/detectron2

To check which models Detectron2 supports, be sure to check that page.

Segmentation Models

 

In semantic segmentation, the goal is to classify each pixel into the given classes. In instance segmentation, we care about segmentation of the instances of objects separately. The panoptic segmentation combines semantic and instance segmentation such that all pixels are assigned a class label and all object instances are uniquely segmented.

Read more about semantic segmentation and instance segmentation if the terms are unfamiliar.

 

<semantic segmentation>

<instance segmentation>

<panoptic segmentation>

Detectron2 provides the model configurations for panoptic segmentation in the configs/COCO-PanopticSegmentation directory. For instance segmentation, the configurations are in the COCO-InstanceSegmentation and LVIS-InstanceSegmentation directories.

Let's check each model's usage and features through examples.

COCO-InstanceSegmentation


from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2
import numpy as np
import requests, sys, time, os
from PIL import Image
import argparse

help = 'mask_rcnn_R_101_C4_3x, mask_rcnn_R_101_DC5_3x, mask_rcnn_R_101_FPN_3x'
help += ',mask_rcnn_R_50_C4_1x'
help += ',mask_rcnn_R_50_C4_3x'
help += ',mask_rcnn_R_50_DC5_1x'
help += ',mask_rcnn_R_50_DC5_3x'
help += ',mask_rcnn_R_50_FPN_1x'
help += ',mask_rcnn_R_50_FPN_1x_giou'
help += ',mask_rcnn_R_50_FPN_3x'
help += ',mask_rcnn_X_101_32x8d_FPN_3x'

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default = 'mask_rcnn_R_50_FPN_3x', help = help)
parser.add_argument('--size', type=str, default = '640X480', help = 'image inference size ex:320X240')

opt = parser.parse_args()

W, H = opt.size.split('X')
url = 'http://images.cocodataset.org/val2017/000000439715.jpg'

img = Image.open(requests.get(url, stream=True).raw).resize((int(W),int(H)))
im = np.asarray(img, dtype="uint8")
height, width, channels = im.shape
if channels == 3:
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
else:
    im = cv2.cvtColor(im, cv2.COLOR_RGBA2BGR)

print('image W:%d H:%d'%(width, height))

network_model = 'COCO-InstanceSegmentation/' + opt.model + '.yaml'


cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)
predictor = DefaultPredictor(cfg)


for i in range (2):
    fps_time  = time.perf_counter()
    outputs = predictor(im)
    fps = 1.0 / (time.perf_counter() - fps_time)
    print("Net FPS: %f" % (fps))
    v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    
cv2.imwrite("detectron2_%s_result.jpg"%(opt.model), out.get_image()[:, :, ::-1])
cv2.imwrite("./source_image.jpg", im)
<instance_segmentation.py>

You can test the code above with different models and inference image sizes using the --model and --size parameters, like this.

root@jetpack-4:/usr/local/src/detectron2# python3 instance_segmentation.py --model=mask_rcnn_R_50_FPN_1x
image W:640 H:480
model_final_a54504.pkl: 178MB [00:24, 7.26MB/s]
Net FPS: 0.137043
Net FPS: 0.239866

root@jetpack-4:/usr/local/src/detectron2# python3 instance_segmentation.py --model=mask_rcnn_R_50_FPN_3x
image W:640 H:480
Net FPS: 0.091935
Net FPS: 0.232827

The following is a summary of execution times applying various models.

Model                        | First FPS | Second FPS | Memo
mask_rcnn_R_50_FPN_1x        | 0.137043  | 0.239866   |
mask_rcnn_R_50_FPN_3x        | 0.091935  | 0.232827   |
mask_rcnn_R_50_DC5_3x        | 0.041882  | 0.140059   |
mask_rcnn_X_101_32x8d_FPN_3x | 0.057355  | 0.093649   |
mask_rcnn_R_50_DC5_1x        | 0.032277  | 0.122482   |
mask_rcnn_R_101_DC5_3x       | 0.034573  | 0.116745   |
mask_rcnn_R_101_FPN_3x       | 0.069153  | 0.187640   |
mask_rcnn_R_50_C4_3x         |           |            | process killed
mask_rcnn_R_50_FPN_1x_giou   |           |            | RuntimeError
mask_rcnn_R_50_C4_1x         |           |            | process killed
mask_rcnn_R_101_C4_3x        |           |            | process killed

The following figures show the output image for each model in the table above.



Under the hood

Now let's look at the values in the model's output. Modify the code above slightly to print the output values with print statements.


for i in range (2):
    fps_time  = time.perf_counter()
    outputs = predictor(im)
    fps = 1.0 / (time.perf_counter() - fps_time)
    if i == 1:
        print('===== pred_boxes =====')
        print(outputs["instances"].pred_boxes)

        print('===== scores =====')
        print(outputs["instances"].scores)

        print('===== pred_classes =====')
        print(outputs["instances"].pred_classes)

        print('===== pred_masks shape=====')
        print(np.shape(outputs["instances"].pred_masks))


    print("Net FPS: %f" % (fps))


    v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))



If you run the modified code, you will get output like the following.

root@jetpack-4:/usr/local/src/detectron2# python3 instance_segmentation.py --model=mask_rcnn_R_101_FPN_3x
image W:640 H:480
Net FPS: 0.072939
===== pred_boxes =====
Boxes(tensor([[134.7698, 244.0194, 475.4960, 479.9181],
        [114.9002, 270.0168, 147.6761, 394.1597],
        [562.5130, 271.3448, 596.8009, 380.1035],
        [253.5800, 165.3687, 334.2534, 404.7804],
        [ 48.8288, 275.8104,  79.0881, 347.1985],
        [516.7025, 281.0125, 561.6487, 344.8624],
        [  0.7477, 274.8362,  78.6029, 474.8378],
        [385.6545, 272.2646, 411.7497, 302.4032],
        [329.5269, 231.0193, 393.1201, 256.7559],
        [405.9588, 273.3003, 459.8397, 359.6782],
        [510.3246, 266.7764, 571.3597, 287.0667],
        [330.4257, 251.8549, 415.3971, 276.1583],
        [344.5151, 269.3447, 383.9334, 298.3615]], device='cuda:0'))
===== scores =====
tensor([0.9992, 0.9969, 0.9922, 0.9920, 0.9905, 0.9737, 0.9641, 0.9495, 0.9476,
        0.9309, 0.9207, 0.8759, 0.7874], device='cuda:0')
===== pred_classes =====
tensor([17,  0,  0,  0,  0,  0,  0,  0, 25,  0, 25, 25,  0], device='cuda:0')
===== pred_masks shape=====
torch.Size([13, 480, 640])
Net FPS: 0.189391

Investigating these results leads to the following conclusions.
The output is a dictionary containing the following important values:
  • "pred_boxes": the list of box coordinates for the recognized objects in the image. In the result above there are 13 boxes; that is, 13 objects were recognized and their positions are given as box coordinates.
  • "scores": the confidence scores of the 13 objects. Since we set the threshold to 0.5, objects with scores below 0.5 are not included.
  • "pred_classes": the label index of each of the 13 objects.
  • "pred_masks": per-pixel mask information for the 13 objects, stored as 13 black/white masks. Each mask has the same size as the inference image, and the tensor has the shape [number of instances, height, width].
For reference, the COCO label indices are listed below. Mask R-CNN uses the COCO classes; in Detectron2, indexing starts at 0 with 'person', so class 17 in the output above corresponds to 'horse'. The 50 and 101 in the model names indicate the depth (number of layers) of the ResNet backbone; more layers usually give slightly better accuracy, but at the cost of speed. If you want readable names instead of raw indices, see the short snippet after the list.

class_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane',
               'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
               'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
               'kite', 'baseball bat', 'baseball glove', 'skateboard',
               'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
               'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
               'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
               'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
               'teddy bear', 'hair drier', 'toothbrush']
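
If you want readable class names instead of raw index values, the same metadata that the Visualizer uses can be queried directly. The following short snippet is illustrative only (it is not part of the original script) and assumes the cfg and outputs variables from instance_segmentation.py are still in scope.

# Illustrative: map pred_classes indices to class names via the dataset metadata
thing_classes = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
for cls_id, score in zip(outputs["instances"].pred_classes, outputs["instances"].scores):
    # cls_id and score are 0-dim tensors; convert them to plain Python values
    print('%s: %.2f' % (thing_classes[int(cls_id)], float(score)))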


Documentation about the model output format can be found at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format.

Model Output Format

When in training mode, the builtin models output a dict[str->ScalarTensor] with all the losses.

When in inference mode, the builtin models output a list[dict], one dict for each image. Based on the tasks the model is doing, each dict may contain the following fields:

  • “instances”: Instances object with the following fields:

    • “pred_boxes”: Boxes object storing N boxes, one for each detected instance.

    • “scores”: Tensor, a vector of N scores.

    • “pred_classes”: Tensor, a vector of N labels in range [0, num_categories).

    • “pred_masks”: a Tensor of shape (N, H, W), masks for each detected instance.

    • “pred_keypoints”: a Tensor of shape (N, num_keypoint, 3). Each row in the last dimension is (x, y, score). Scores are larger than 0.

  • “sem_seg”: Tensor of (num_categories, H, W), the semantic segmentation prediction.

  • “proposals”: Instances object with the following fields:

    • “proposal_boxes”: Boxes object storing N boxes.

    • “objectness_logits”: a torch vector of N scores.

  • “panoptic_seg”: A tuple of (Tensor, list[dict]). The tensor has shape (H, W), where each element represent the segment id of the pixel. Each dict describes one segment id and has the following fields:

    • “id”: the segment id

    • “isthing”: whether the segment is a thing or stuff

    • “category_id”: the category id of this segment. It represents the thing class id when isthing==True, and the stuff class id otherwise.
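
As a quick illustration of the "instances" fields listed above (this snippet is my addition, not part of the Detectron2 documentation), the instance outputs can be moved to the CPU and converted to plain NumPy arrays for further processing; it assumes the outputs variable from the scripts above.

# Illustrative: convert the instance fields to NumPy arrays
inst = outputs["instances"].to("cpu")
boxes = inst.pred_boxes.tensor.numpy()   # (N, 4) boxes as x0, y0, x1, y1
scores = inst.scores.numpy()             # (N,) confidence scores
classes = inst.pred_classes.numpy()      # (N,) class indices
masks = inst.pred_masks.numpy()          # (N, H, W) boolean masks
print(boxes.shape, scores.shape, classes.shape, masks.shape)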


Display only box data

Above we looked at the important output values. Among them, pred_boxes holds the bounding box coordinates of each detected object. In the example above, Detectron2's Visualizer was used to draw the results.
This time, let's draw only the information we want ourselves, using PIL and OpenCV, without the Visualizer.


from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2
import numpy as np
import requests, sys, time, os
from PIL import Image, ImageDraw
import argparse


COLORS = [(0, 45, 74, 127), (85, 32, 98, 127), (93, 69, 12, 127),
          (49, 18, 55, 127), (46, 67, 18, 127), (30, 74, 93, 127)]
LINE_COLORS = [(0, 45, 74, 255), (85, 32, 98, 255), (93, 69, 12, 255),
          (49, 18, 55, 255), (46, 67, 18, 255), (30, 74, 93, 255)]
help = 'mask_rcnn_R_101_C4_3x, mask_rcnn_R_101_DC5_3x, mask_rcnn_R_101_FPN_3x'
help += ',mask_rcnn_R_50_C4_1x'
help += ',mask_rcnn_R_50_C4_3x'
help += ',mask_rcnn_R_50_DC5_1x'
help += ',mask_rcnn_R_50_DC5_3x'
help += ',mask_rcnn_R_50_FPN_1x'
help += ',mask_rcnn_R_50_FPN_1x_giou'
help += ',mask_rcnn_R_50_FPN_3x'
help += ',mask_rcnn_X_101_32x8d_FPN_3x'

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default = 'mask_rcnn_R_50_FPN_3x', help = help)
parser.add_argument('--file', type=str, default = '')
parser.add_argument('--size', type=str, default = '640X480', help = 'image inference size ex:320X240')

opt = parser.parse_args()

W, H = opt.size.split('X')
if opt.file == '':
    url = 'http://images.cocodataset.org/val2017/000000439715.jpg'
    img = Image.open(requests.get(url, stream=True).raw).resize((int(W),int(H)))
    im = np.asarray(img, dtype="uint8")
    height, width, channels = im.shape
    if channels == 3:
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    else:
        im = cv2.cvtColor(im, cv2.COLOR_RGBA2BGR)
    
else:
    im = cv2.imread(opt.file, cv2.IMREAD_COLOR)
    height, width, channels = im.shape
    

print('image W:%d H:%d'%(width, height))

network_model = 'COCO-InstanceSegmentation/' + opt.model + '.yaml'


cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)
predictor = DefaultPredictor(cfg)


fps_time  = time.perf_counter()
outputs = predictor(im)
fps = 1.0 / (time.perf_counter() - fps_time)
print('===== pred_boxes =====')
print(outputs["instances"].pred_boxes)

print('===== scores =====')
print(outputs["instances"].scores)

print('===== pred_classes =====')
print(outputs["instances"].pred_classes)

print('===== pred_masks len=====')
print(np.shape(outputs["instances"].pred_masks))

print("Net FPS: %f" % (fps))


'''
I'm going to use PIL draw, because it's very easy to draw alpha rectangle
'''

im2 = Image.fromarray(np.uint8(im))
drw = ImageDraw.Draw(im2, 'RGBA')
pred_boxes = outputs["instances"].pred_boxes
classes = outputs["instances"].pred_classes

color_index = 0
for cs, box in zip(classes, pred_boxes):
    box = box.cpu().numpy()   # convert the box tensor to plain numbers for PIL
    cs = cs.cpu()
    x0, y0, x1, y1 = box
    color = COLORS[color_index % len(COLORS)]
    linecolor = LINE_COLORS[color_index % len(LINE_COLORS)]
    color_index += 1
    # draw an opaque outline and a translucent fill for each box
    drw.rectangle([x0, y0, x1, y1], outline = linecolor, width=5)
    drw.rectangle([x0, y0, x1, y1], fill = color)

im2.save("detectron2_box_%s_result.jpg"%(opt.model))
<instance_segmentation_box.py>


If you run the Python code, you get an image file like the one below. Only the box values from the output were drawn, using PIL's drawing functions.
<detectron2_box_mask_rcnn_R_50_FPN_3x_result.jpg>


Display only mask data

Drawing mask data is a bit tricky. You need to be familiar with OpenCV or PIL's image processing capabilities.

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import cv2
import numpy as np
import requests, sys, time, os
from PIL import Image, ImageDraw
import argparse


def process_alpha_masking(base, mask, color):
    h, w, _ = base.shape
    bg = base.copy()
    R = color[0]
    G = color[1]
    B = color[2]
    alpha = color[3] / 255
    
    try:
        for i in range(0, h):
            for j in range(0, w):
                val = mask[i][j]
                if val == True:
                    bg[i][j][0] = int(B * alpha + bg[i][j][0] * (1 - alpha))
                    bg[i][j][1] = int(G * alpha + bg[i][j][1] * (1 - alpha))
                    bg[i][j][2] = int(R * alpha + bg[i][j][2] * (1 - alpha))
    except IndexError:  #index (i, j) is out of the screen resolution.  
        print(' index Error')
        return None
    return bg


COLORS = [(0, 45, 74, 224), (85, 32, 98, 224), (93, 69, 12, 224),
          (49, 18, 55, 224), (46, 67, 18, 224), (30, 74, 93, 224)]

help = 'mask_rcnn_R_101_C4_3x, mask_rcnn_R_101_DC5_3x, mask_rcnn_R_101_FPN_3x'
help += ',mask_rcnn_R_50_C4_1x'
help += ',mask_rcnn_R_50_C4_3x'
help += ',mask_rcnn_R_50_DC5_1x'
help += ',mask_rcnn_R_50_DC5_3x'
help += ',mask_rcnn_R_50_FPN_1x'
help += ',mask_rcnn_R_50_FPN_1x_giou'
help += ',mask_rcnn_R_50_FPN_3x'
help += ',mask_rcnn_X_101_32x8d_FPN_3x'

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default = 'mask_rcnn_R_50_FPN_3x', help = help)
parser.add_argument('--file', type=str, default = '')
parser.add_argument('--size', type=str, default = '640X480', help = 'image inference size ex:320X240')

opt = parser.parse_args()

W, H = opt.size.split('X')
if opt.file == '':
    url = 'http://images.cocodataset.org/val2017/000000439715.jpg'
    img = Image.open(requests.get(url, stream=True).raw).resize((int(W),int(H)))
    im = np.asarray(img, dtype="uint8")
    height, width, channels = im.shape
    if channels == 3:
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    else:
        im = cv2.cvtColor(im, cv2.COLOR_RGBA2BGR)
    
else:
    im = cv2.imread(opt.file, cv2.IMREAD_COLOR)
    height, width, channels = im.shape
    

print('image W:%d H:%d'%(width, height))

network_model = 'COCO-InstanceSegmentation/' + opt.model + '.yaml'


cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)
predictor = DefaultPredictor(cfg)


fps_time  = time.perf_counter()
outputs = predictor(im)
fps = 1.0 / (time.perf_counter() - fps_time)
print('===== pred_boxes =====')
print(outputs["instances"].pred_boxes)

print('===== scores =====')
print(outputs["instances"].scores)

print('===== pred_classes =====')
print(outputs["instances"].pred_classes)

print('===== pred_masks len=====')
print(np.shape(outputs["instances"].pred_masks))

print("Net FPS: %f" % (fps))


'''
This time,  I'm going to use OpenCV functions
'''

im2 = im.copy()
pred_masks = outputs["instances"].pred_masks
classes = outputs["instances"].pred_classes

color_index = 0
for cs, mask in zip(classes, pred_masks):
    mask = mask.cpu()
    cs = cs.cpu()
    color = COLORS[color_index % len(COLORS)]
    color_index += 1
    im2 = process_alpha_masking(im2, mask, color)
    print("Mask Processing: %d" % (color_index))

cv2.imwrite("detectron2_mask_%s_result.jpg"%(opt.model), im2)
<instance_segmentation_mask.py>

If you run the Python code, you get an image file like the one below; only the mask values from the output were drawn. Since every pixel in the image has to be compared against the mask data, the number of operations is (number of recognized objects) X (image width) X (image height), so processing all the mask data takes a long time. Detectron2's Visualizer is much faster, presumably because its image processing runs in vectorized NumPy/C code rather than in Python loops; in Python, large loops have a big impact on performance. A vectorized sketch is shown after the result image below.

<detectron2_mask_mask_rcnn_R_50_FPN_3x_result.jpg>
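
For reference, the per-pixel loop in process_alpha_masking can be replaced by a single vectorized NumPy operation, which removes most of the Python overhead. The function below is a sketch I am adding for illustration (it is not part of the original script); it takes the same arguments as process_alpha_masking and should produce the same image.

def process_alpha_masking_fast(base, mask, color):
    # Vectorized alternative: blend the RGBA color into the masked pixels
    # with NumPy array operations instead of looping over every pixel.
    bg = base.copy().astype(np.float32)
    m = np.asarray(mask, dtype=bool)                                  # (H, W) boolean mask
    alpha = color[3] / 255.0
    bgr = np.array([color[2], color[1], color[0]], dtype=np.float32)  # the image is in BGR (OpenCV order)
    bg[m] = bgr * alpha + bg[m] * (1.0 - alpha)
    return bg.astype(np.uint8)

Swapping this function into instance_segmentation_mask.py should give the same result image far more quickly, because the blending runs inside NumPy's compiled code.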


COCO-PanopticSegmentation

Now, I will look at COCO-PanopticSegmentation. You can check the currently supported models on the https://github.com/facebookresearch/detectron2/tree/master/configs/COCO-PanopticSegmentation page. The panoptic models are loaded and run in exactly the same way as the instance segmentation models described earlier, so the same code can be used.

As explained earlier in the article, panoptic segmentation is a combination of semantic and instance segmentation. However, when I ran this model with the code above, the semantic ("stuff") part of the result did not appear, most likely because the script only draws the "instances" output. A sketch of how the panoptic output could be drawn is shown after the result image below.

<detectron2_panoptic_mask_rcnn_R_50_FPN_3x_result.jpg>
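
If you also want the semantic ("stuff") regions in the result, the panoptic output has to be read from the "panoptic_seg" field and drawn with draw_panoptic_seg_predictions instead of draw_instance_predictions. The following is a minimal sketch based on the output format quoted earlier; the config name and the im variable are assumptions that follow the earlier scripts.

# Minimal sketch: visualize the panoptic output (assumes `im` is prepared as in the scripts above)
network_model = 'COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml'
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)
predictor = DefaultPredictor(cfg)

outputs = predictor(im)
panoptic_seg, segments_info = outputs["panoptic_seg"]   # (H, W) segment ids + per-segment info
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
out = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info)
cv2.imwrite("detectron2_panoptic_result.jpg", out.get_image()[:, :, ::-1])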

LVIS(Large Vocabulary Instance Segmentation)-Segmentation

LVIS is a dataset built on 164K COCO images and annotated with over 1,200 object categories (1,230 in the version Detectron2 uses). Unlike the panoptic or COCO instance segmentation models, the LVIS models distinguish very specific objects such as jackets, buttons, ties, and jeans, as shown in the figure below. To try one of these models, only the config path needs to change; a brief sketch follows.
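
The sketch below is an assumption that reuses the imports and the im variable from instance_segmentation.py; the LVIS metadata then supplies the 1,230 class names for visualization.

# Sketch: run an LVIS instance segmentation model (reuses the setup from instance_segmentation.py)
network_model = 'LVIS-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml'
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)
predictor = DefaultPredictor(cfg)

outputs = predictor(im)
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.imwrite("detectron2_lvis_result.jpg", out.get_image()[:, :, ::-1])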




How can I get the class names of each model?

The various segmentation models provided by Detectron2 use different class taxonomies.
For example, LVIS-InstanceSegmentation provides 1,230 classes, while Cityscapes offers eight categories.

Using Detectron2's metadata, you can easily check this and other information.
The official documentation on Detectron2's metadata is available at https://detectron2.readthedocs.io/tutorials/datasets.html#metadata-for-datasets.


from detectron2 import model_zoo
from detectron2.data import MetadataCatalog
from detectron2.config import get_cfg

network_model = 'COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'

cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)

cls = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
print('NetworkModel[%s] category:%d'%(network_model, len(cls)))
print(cls)

network_model = 'Cityscapes/mask_rcnn_R_50_FPN.yaml'
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)

cls = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
print('NetworkModel[%s] category:%d'%(network_model, len(cls)))
print(cls)

network_model = 'LVIS-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml'
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(network_model))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(network_model)

cls = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
print('NetworkModel[%s] category:%d'%(network_model, len(cls)))
#Too many classes to print
#print(cls)
<metadata.py>

Run the code and check the result.

root@jetpack-4:/usr/local/src/detectron2_test# python3 metadata.py
NetworkModel[COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml] category:80
['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

NetworkModel[Cityscapes/mask_rcnn_R_50_FPN.yaml] category:8
['person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
NetworkModel[LVIS-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml] category:1230

As you can see from the results above, you can easily check the number and names of the classes supported by the model you want to use.
In addition to thing_classes, which lists the classifiable classes, MetadataCatalog provides various other metadata. Information such as keypoint_names and keypoint_flip_map is useful for Detectron2's pose estimation, which will be introduced next; a small example is shown below.
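
For example, keypoint metadata can be queried in exactly the same way; the short snippet below is a hypothetical illustration (not part of the original post) using one of the COCO keypoint configs.

# Hypothetical illustration: query keypoint metadata from a COCO-Keypoints config
kp_model = 'COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml'
kp_cfg = get_cfg()
kp_cfg.merge_from_file(model_zoo.get_config_file(kp_model))
meta = MetadataCatalog.get(kp_cfg.DATASETS.TRAIN[0])
print(meta.keypoint_names)      # e.g. ('nose', 'left_eye', 'right_eye', ...)
print(meta.keypoint_flip_map)   # keypoint pairs swapped under horizontal flip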

Wrapping up

The previous tests show that Detectron2's segmentation accuracy is excellent. However, it is too slow to be practical on the Jetson Nano; the main reason is that Detectron2 uses heavy, high-accuracy models.
In the next post, I'll cover the Keypoint detection and DensePose of Detectron2 that I'm most interested in.
You can download the source code at https://github.com/raspberry-pi-maker/NVIDIA-Jetson .










