In the previous article (Xavier NX - YOLOv8 Video Object Detection (JetPack 5.1)), I introduced the video_detect_cv.py example.
In that example, even though yolov8n.pt, the lightest of the YOLOv8 models, was used, the Net FPS on the Xavier NX was only 18.27.
If you switch the model to yolov8s.pt, yolov8m.pt, or yolov8l.pt, accuracy goes up, but speed drops sharply.
In fact, after changing the model to yolov8m.pt and testing, the Net FPS dropped to 9.07, a value that is hard to use in a production environment.
On platforms with NVIDIA GPUs, the best way to get maximum speed is to use TensorRT. Fortunately, the Jetson series we use already has CUDA and TensorRT installed.
In this article, we will learn how to convert a YOLOv8 model to a TensorRT model, how to use it, and how to improve performance.
TensorRT for Anaconda Python
First, install the TensorRT Python bindings on the Xavier NX.
TensorRT itself is preinstalled on the Jetson series, but the Python bindings are not.
So, install the Python bindings with the following commands:
sudo apt-get install python3-libnvinfer
sudo apt-get install python3-libnvinfer-dev
One thing to note is that because you use the apt command, the TensorRT package for Python is installed in Xavier NX's default Python, not in a virtual environment.
You can check it like this:
spypiggy@spypiggy-NX:/usr/lib/python3.8/dist-packages$ ll
total 44
drwxr-xr-x 6 root root 4096 Jun 7 15:34 ./
drwxr-xr-x 32 root root 20480 Jun 7 14:56 ../
drwxrwxr-x 7 root root 4096 Jan 26 13:12 cv2/
drwxr-xr-x 6 root root 4096 May 19 2021 numpy/
drwxr-xr-x 2 root root 4096 Jun 7 15:34 tensorrt/
drwxr-xr-x 2 root root 4096 Jun 7 15:34 tensorrt-8.5.2.2.dist-info/
I will copy these packages into the virtual environment. For reference, the name of the virtual environment I am using is "yolov8".
Since the packages were installed with the sudo apt command, they are owned by root. Therefore, after copying them, I will change their ownership to the user account spypiggy.
(base) spypiggy@spypiggy-NX:~/anaconda3/envs/yolov8/lib/python3.8/site-packages$ sudo cp -r /usr/lib/python3.8/dist-packages/tensorrt ./
(base) spypiggy@spypiggy-NX:~/anaconda3/envs/yolov8/lib/python3.8/site-packages$ sudo cp -r /usr/lib/python3.8/dist-packages/tensorrt-8.5.2.2.dist-info/ ./
(base) spypiggy@spypiggy-NX:~/anaconda3/envs/yolov8/lib/python3.8/site-packages$ sudo chown spypiggy:spypiggy tensorrt
(base) spypiggy@spypiggy-NX:~/anaconda3/envs/yolov8/lib/python3.8/site-packages$ sudo chown spypiggy:spypiggy tensorrt-8.5.2.2.dist-info/
Now let's check if tensorrt can be imported in the virtual environment.
(base) spypiggy@spypiggy-NX:~/anaconda3/envs/yolov8/lib/python3.8/site-packages$ conda activate yolov8
(yolov8) spypiggy@spypiggy-NX:~/anaconda3/envs/yolov8/lib/python3.8/site-packages$ python
Python 3.8.16 (default, Mar 2 2023, 03:16:31)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorrt
>>>
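Beyond a bare import, you can also print the binding's version string to confirm it matches the system TensorRT (8.5.2 in this setup). A minimal sanity check using the standard tensorrt Python API:

import tensorrt as trt

# The binding version should match the system TensorRT (8.5.2.2 here).
print(trt.__version__)

# Creating a logger exercises the underlying native library.
logger = trt.Logger(trt.Logger.WARNING)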
I confirmed that it works fine. Now, let's start converting the YOLOv8 model to TensorRT in earnest.
YOLOv8 to ONNX
The Ultralytics home page lists export options for converting the YOLOv8 model to various frameworks. But when I tested it, the direct conversion to TensorRT did not go well.
For reference, since TensorRT engines are version-sensitive, it is better to do the TensorRT conversion directly on the Xavier NX.
Therefore, I decided to convert the YOLOv8 model to ONNX format first and then convert the ONNX model to TensorRT.
I will be using triple-Mu's GitHub repository. triple-Mu's program not only converts the model to ONNX, but also adds a bbox decoder and NMS to the ONNX model in one step. The difference from using the API provided by Ultralytics will be shown later.
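For comparison, the plain Ultralytics export looks roughly like the sketch below (based on the documented Ultralytics export API). The resulting ONNX graph ends at the raw prediction tensor, without the decoder and NMS that triple-Mu adds:

from ultralytics import YOLO

# Plain Ultralytics export: the ONNX graph ends at the raw head output,
# so decoding and NMS must still be done in Python postprocessing.
model = YOLO("yolov8s.pt")
model.export(format="onnx", opset=11, simplify=True)  # writes yolov8s.onnx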
(base) spypiggy@spypiggy-NX:~/src$ git clone https://github.com/triple-Mu/YOLOv8-TensorRT
(base) spypiggy@spypiggy-NX:~/src$ conda activate yolov8
(yolov8) spypiggy@spypiggy-NX:~/src$ cd YOLOv8-TensorRT/
(yolov8) spypiggy@spypiggy-NX:~/src/YOLOv8-TensorRT$ ls -al
total 104
drwxrwxr-x 9 spypiggy spypiggy 4096 Jun 7 16:55 .
drwxrwxr-x 9 spypiggy spypiggy 4096 Jun 7 18:48 ..
-rw-rw-r-- 1 spypiggy spypiggy 1912 Jun 7 16:37 build.py
-rw-rw-r-- 1 spypiggy spypiggy 1817 Jun 7 16:37 config.py
drwxrwxr-x 7 spypiggy spypiggy 4096 Jun 7 16:37 csrc
drwxrwxr-x 2 spypiggy spypiggy 4096 Jun 7 16:37 data
drwxrwxr-x 2 spypiggy spypiggy 4096 Jun 7 16:37 docs
-rwxrwxr-x 1 spypiggy spypiggy 3138 Jun 7 16:37 export-det.py
-rw-rw-r-- 1 spypiggy spypiggy 2302 Jun 7 16:37 export-seg.py
-rw-rw-r-- 1 spypiggy spypiggy 1307 Jun 7 16:37 gen_pkl.py
drwxrwxr-x 8 spypiggy spypiggy 4096 Jun 7 16:37 .git
-rw-rw-r-- 1 spypiggy spypiggy 1862 Jun 7 16:37 .gitignore
-rw-rw-r-- 1 spypiggy spypiggy 2716 Jun 7 17:33 infer-det.py
-rw-rw-r-- 1 spypiggy spypiggy 2652 Jun 7 16:37 infer-det-without-torch.py
-rw-rw-r-- 1 spypiggy spypiggy 3867 Jun 7 16:37 infer-seg.py
-rw-rw-r-- 1 spypiggy spypiggy 3646 Jun 7 16:37 infer-seg-without-torch.py
-rw-rw-r-- 1 spypiggy spypiggy 1065 Jun 7 16:37 LICENSE
drwxrwxr-x 3 spypiggy spypiggy 4096 Jun 7 16:39 models
drwxrwxr-x 2 spypiggy spypiggy 4096 Jun 7 17:39 output
-rw-rw-r-- 1 spypiggy spypiggy 646 Jun 7 16:37 .pre-commit-config.yaml
drwxrwxr-x 2 spypiggy spypiggy 4096 Jun 7 16:39 __pycache__
-rw-rw-r-- 1 spypiggy spypiggy 8238 Jun 7 16:37 README.md
-rw-rw-r-- 1 spypiggy spypiggy 105 Jun 7 16:37 requirements.txt
-rw-rw-r-- 1 spypiggy spypiggy 767 Jun 7 16:37 trt-profile.py
To convert the model, use export-det.py. The usage is well explained in the README.md file on GitHub. Adjust the YOLOv8 model path (the --weights value below) to suit your environment.
python export-det.py \
--weights ../yolov8/yolov8s.pt \
--iou-thres 0.65 \
--conf-thres 0.15 \
--topk 100 \
--opset 11 \
--sim \
--input-shape 1 3 640 640 \
--device cuda:0
After a while, you can see that the onnx file has been created as follows.
(base) spypiggy@spypiggy-NX:~/src/yolov8$ ls -al yolov8s*
-rw-rw-r-- 1 spypiggy spypiggy 44777438 Jun 7 17:08 yolov8s.onnx
-rw------- 1 spypiggy spypiggy 22573363 Mar 14 22:56 yolov8s.pt
For reference, the following figure shows the difference between an ONNX model created with the Ultralytics API and the ONNX model created above.
If you load an ONNX model at https://netron.app/, you can inspect the model's network structure as a graph, as shown above. With triple-Mu's program, the exported ONNX model directly produces the final detection results.
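If you would rather inspect the exported graph from Python instead of Netron, here is a minimal sketch using the onnx package (assuming it is installed in the virtual environment); the output names below are those produced by triple-Mu's export:

import onnx

model = onnx.load("yolov8s.onnx")
# Expected output names with triple-Mu's export: num_dets, bboxes, scores, labels
for out in model.graph.output:
    dims = [d.dim_value for d in out.type.tensor_type.shape.dim]
    print(out.name, dims)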
ONNX to TensorRT
Now it's time to convert the ONNX model to a TensorRT engine. You can use either the trtexec program provided by TensorRT or the Python script provided by triple-Mu.
TensorRT recommends fp16 over fp32 in many cases: GPU memory usage is reduced and processing speed improves, in exchange for a slight loss in accuracy.
python3 build.py \
--weights ../yolov8/yolov8s.onnx \
--iou-thres 0.65 \
--conf-thres 0.15 \
--topk 100 \
--fp16 \
--device cuda:0
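For reference, build.py drives the TensorRT builder API under the hood. The sketch below is my own illustration of what an FP16 build involves, using the TensorRT 8.x Python bindings installed earlier; it is not triple-Mu's actual code:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

# Parse the ONNX model exported above.
with open("yolov8s.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit("ONNX parse failed")

config = builder.create_builder_config()
# Keep the workspace modest; the Xavier NX shares memory between CPU and GPU.
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)  # use half precision where safe

# Build and serialize the engine to disk.
engine = builder.build_serialized_network(network, config)
with open("yolov8s.engine", "wb") as f:
    f.write(engine)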
If you want to use trtexec provided by TensorRT, you must compile the source code as follows.
(base) spypiggy@spypiggy-NX:/usr/src/tensorrt/samples/trtexec$ pwd
/usr/src/tensorrt/samples/trtexec
(base) spypiggy@spypiggy-NX:/usr/src/tensorrt/samples/trtexec$ ls -al
total 52
drwxr-xr-x 2 root root 4096 Jan 26 13:10 .
drwxr-xr-x 14 root root 4096 Jan 26 13:10 ..
-rw-r--r-- 1 root root 223 Dec 6 2022 Makefile
-rwxr-xr-x 1 root root 2343 Dec 6 2022 prn_utils.py
-rwxr-xr-x 1 root root 6029 Dec 6 2022 profiler.py
-rw-r--r-- 1 root root 8798 Dec 6 2022 README.md
-rwxr-xr-x 1 root root 4054 Dec 6 2022 tracer.py
-rw-r--r-- 1 root root 12057 Dec 6 2022 trtexec.cpp
(base) spypiggy@spypiggy-NX:/usr/src/tensorrt/samples/trtexec$ sudo make    # no need to run "make install"
(base) spypiggy@spypiggy-NX:/usr/src/tensorrt$ ll bin
total 10340
drwxr-xr-x 4 root root 4096 Jun 7 12:49 ./
drwxr-xr-x 5 root root 4096 Jan 26 13:10 ../
drwxr-xr-x 3 root root 4096 Jun 7 12:47 chobj/
drwxr-xr-x 3 root root 4096 Jun 7 12:46 dchobj/
-rwxr-xr-x 1 root root 2373392 Jun 7 12:49 trtexec*
-rwxr-xr-x 1 root root 8194704 Jun 7 12:47 trtexec_debug*
If the build succeeds, you can convert it like this:
(yolov8) spypiggy@spypiggy-NX:~/src/yolov8$ /usr/src/tensorrt/bin/trtexec \
--onnx=yolov8s.onnx \
--saveEngine=yolov8s.engine \
--fp16 \
--memPoolSize=workspace:4000
For reference, the workspace size is reduced here because the Xavier NX has less memory than a typical PC.
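If you want to check how much memory is available before choosing a workspace size, here is a quick sketch using PyTorch (an assumption on my part: it requires a CUDA-enabled PyTorch 1.11 or later, which the YOLOv8 environment already provides):

import torch

# Free and total device memory in bytes; on Jetson, CPU and GPU share this memory.
free, total = torch.cuda.mem_get_info()
print("free: %.2f GiB / total: %.2f GiB" % (free / 2**30, total / 2**30))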
If the yolov8s.engine file is created after a while, the conversion succeeded.
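Before moving on, you can sanity-check the generated engine by deserializing it and listing its I/O bindings. A minimal sketch with the TensorRT 8.x runtime API:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("yolov8s.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# One input (images) plus the four detection outputs should be listed.
for i in range(engine.num_bindings):
    print(engine.get_binding_name(i), engine.get_binding_shape(i))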
Test TensorRT model
You can test it simply with the following command:
python infer-det.py --engine ../yolov8/yolov8s.engine \
--imgs data \
--show \
--out-dir outputs \
--device cuda:0
If you do not use the --show option, the detection results from the TensorRT model are saved to the outputs directory. If you use the --show option, the following images are displayed on the screen.
Comparison of YOLOv8 and TensorRT
Now let's compare the performance of the original YOLOv8 model and the model converted to TensorRT.
Accuracy
The two images below are the output files created using the YOLOv8 command.
(yolov8) spypiggy@spypiggy-NX:~/src/yolov8$ yolo predict model=yolov8s.pt source='https://ultralytics.com/images/bus.jpg'
(yolov8) spypiggy@spypiggy-NX:~/src/yolov8$ yolo predict model=yolov8s.pt source='https://ultralytics.com/images/zidane.jpg'
The exact difference in recognition rate could be measured by computing F1 scores over a large validation set, but the test above suggests there is no big difference.
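For reference, the F1 score mentioned above is the harmonic mean of precision and recall. A minimal sketch of the computation (the numbers are hypothetical, just to show the formula):

def f1_score(precision: float, recall: float) -> float:
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)

print(f1_score(0.90, 0.85))  # 0.874...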
Performance
The main reason we use TensorRT is its fast processing speed. Now let's look at how the processing speed changes when the existing YOLOv8 model is converted to a TensorRT model.
First, the speed of the existing YOLO model is measured using video_detect_cv.py, which was used in Xavier NX - YOLOv8 Video Object Detection (JetPack 5.1). I will use the yolov8s model.
As I always do, I discard the first inference result and measure the speed from the second inference onward. The test video, highway_traffic.mp4, has a frame size of 1024 x 576.
And for accurate speed measurement, the on-screen video output is omitted. Instead, the result is saved to another video file.
from ultralytics import YOLO
import cv2
import time, sys

colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
font = cv2.FONT_HERSHEY_SIMPLEX
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')

def draw(img, boxes):
    # Draw each detected box with its class label and confidence.
    index = 0
    for box in boxes.data:
        p1 = (int(box[0].item()), int(box[1].item()))
        p2 = (int(box[2].item()), int(box[3].item()))
        img = cv2.rectangle(img, p1, p2, colors[index % len(colors)], 3)
        text = label_map[int(box[5].item())] + " %4.2f"%(box[4].item())
        cv2.putText(img, text, (p1[0], p1[1] - 10), font, fontScale=1,
                    color=colors[index % len(colors)], thickness=2)
        index += 1
    # cv2.imshow("draw", img)
    # cv2.waitKey(1)

# Load a model
model = YOLO("yolov8s.pt")  # load an official model
label_map = model.names

f = 0
net_total = 0.0
total = 0.0
cap = cv2.VideoCapture("./highway_traffic.mp4")

# Skip the first frame: use it only for warm-up inference.
ret, img = cap.read()
if ret == False:
    print('Video File Read Error')
    sys.exit(0)
h, w, c = img.shape
print('Video Frame shape H:%d, W:%d, Channel:%d'%(h, w, c))
out_video = cv2.VideoWriter('./cv_result.mp4', fourcc, cap.get(cv2.CAP_PROP_FPS), (w, h))
results = model(img)  # warm-up inference (result discarded)

while cap.isOpened():
    s = time.time()
    ret, img = cap.read()
    if ret == False:
        break
    net_s = time.time()
    results = model(img)  # predict on an image
    net_e = time.time()
    for result in results:
        draw(result.orig_img, result.boxes)
    e = time.time()
    net_total += (net_e - net_s)  # inference time only
    total += (e - s)              # read + inference + draw time
    f += 1
    out_video.write(result.orig_img)

fps = f / total
net_fps = f / net_total
print("Total processed frames:%d"%f)
print("FPS:%4.2f"%fps)
print("Net FPS:%4.2f"%net_fps)
cv2.destroyAllWindows()
cap.release()
out_video.release()
<video_detect_cv2.py>
The following is the code above, modified to measure performance using the TensorRT engine.
To run this code, you need part of triple-Mu's source code that you downloaded earlier. Copy the models directory and config.py from triple-Mu's repository in advance:
(yolov8) spypiggy@spypiggy-NX:~/src/yolov8$ cp -r ../YOLOv8-TensorRT/models ./
(yolov8) spypiggy@spypiggy-NX:~/src/yolov8$ cp ../YOLOv8-TensorRT/config.py ./
import cv2
import time, sys
import torch

from models import TRTModule  # isort:skip
from models.torch_utils import det_postprocess
from models.utils import blob, letterbox
from config import CLASSES, COLORS

fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')

device = 'cuda:0'
engine = "yolov8s.engine"

# Load the TensorRT engine and query its input size and output bindings.
Engine = TRTModule(engine, device)
H, W = Engine.inp_info[0].shape[-2:]
Engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels'])

f = 0
net_total = 0.0
total = 0.0
cap = cv2.VideoCapture("./highway_traffic.mp4")

# Skip the first frame: use it only for warm-up inference.
ret, img = cap.read()
h, w, c = img.shape
img, ratio, dwdh = letterbox(img, (W, H))
rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
tensor = blob(rgb, return_seg=False)
dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
tensor = torch.asarray(tensor, device=device)
data = Engine(tensor)

out_video = cv2.VideoWriter('./trt_result.mp4', fourcc, cap.get(cv2.CAP_PROP_FPS), (w, h))

while cap.isOpened():
    s = time.time()
    ret, img = cap.read()
    if ret == False:
        break
    draw = img.copy()
    # Letterbox to the engine input size and convert BGR -> RGB -> CHW tensor.
    img, ratio, dwdh = letterbox(img, (W, H))
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    tensor = blob(rgb, return_seg=False)
    dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
    tensor = torch.asarray(tensor, device=device)
    net_s = time.time()
    data = Engine(tensor)
    net_e = time.time()
    # Decode detections and map the boxes back to the original frame.
    bboxes, scores, labels = det_postprocess(data)
    bboxes -= dwdh
    bboxes /= ratio
    for (bbox, score, label) in zip(bboxes, scores, labels):
        bbox = bbox.round().int().tolist()
        cls_id = int(label)
        cls = CLASSES[cls_id]
        color = COLORS[cls]
        cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
        cv2.putText(draw, f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.75, [225, 255, 255], thickness=2)
    # cv2.imshow('result', draw)
    # cv2.waitKey(1)
    e = time.time()
    net_total += (net_e - net_s)  # inference time only
    total += (e - s)              # read + preprocess + inference + draw time
    f += 1
    out_video.write(draw)

fps = f / total
net_fps = f / net_total
print("Total processed frames:%d"%f)
print("FPS:%4.2f"%fps)
print("Net FPS:%4.2f"%net_fps)
cv2.destroyAllWindows()
cap.release()
out_video.release()
<video_detect_cv_trt.py>
Now let's run the two programs and compare their performance.
(yolov8) spypiggy@spypiggy-NX:~/src/yolov8$ python video_detect_cv2.py
......
Total processed frames:1548
FPS:12.19
Net FPS:20.31
(yolov8) spypiggy@spypiggy-NX:~/src/yolov8$ python video_detect_cv_trt.py
......
Total processed frames:1548
FPS:26.56
Net FPS:46.12
It can be seen that the Net FPS rose from 20.31 to 46.12, so the TensorRT model delivers about 2.3 times (230% of) the original throughput.
If you look at the two resulting videos, they show almost identical detection results. The figure below compares two similar frames from the two videos; the results are nearly identical.
Wrapping up
By converting the YOLOv8 model to TensorRT, I was able to create a model that runs more than twice as fast with only a slight decrease in recognition rate.
The Jetson series, including the Xavier NX, uses an NVIDIA GPU, but its performance is much lower than that of PCs with GPUs such as the RTX series, and its relatively small amount of memory can degrade performance further. Therefore, using a TensorRT model converted to fp16 is one of the good ways to compensate for this weakness.
The source code can be downloaded from my GitHub.