目标追踪:使用ByteTrack进行目标检测和跟踪

共 8710字,需浏览 18分钟

 ·

2024-07-14 10:05

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

BYTE算法是一种简单而有效的关联方法,通过关联几乎每个检测框而不仅仅是高分的检测框来跟踪对象。这篇博客的目标是介绍ByteTrack以及多目标跟踪(MOT)的技术。我们还将介绍在样本视频上使用ByteTrack跟踪运行YOLOv8目标检测。


多目标跟踪(MOT)

你可能听说过目标检测,有许多算法如Faster RCNN、SSD和YOLO的各个版本,它们可以以很高的准确性检测物体。但有一个更新的问题是多目标跟踪。基本上,你将传递一个视频流,对于每一帧,你需要检测对象并分配一个“对象ID”,在下一帧中,如果检测到相同的对象,需要分配相同的对象ID。有许多用于MOT的算法,如SORT(简单在线和实时跟踪)、DeepSort、StrongSort等。


有各种用于目标跟踪的方法,包括:

1. 基于特征的跟踪:这涉及基于其特征(如颜色、形状、纹理等)进行跟踪。

2. 模板匹配:正如其名称所示,该方法使用预定义的模板在每个视频序列中进行匹配。

3. 相关性跟踪:该方法用于计算目标对象与后续帧中候选区域的相似性。

4. 基于深度学习的跟踪:该方法使用在大型数据集上训练的神经网络,以实时检测和跟踪对象。


在你可能对MOT有了一些基本的了解。让我们尝试进入ByteTrack并尝试理解为什么它是比DeepSort等更好的目标跟踪。


ByteSort

首先,我们将了解先前MOT算法的问题,然后理解ByteSort的逻辑。


其他MOT算法的问题:

  • 低置信度检测框:其他MOT算法的第一个问题是删除低置信度的检测框。而ByteTrack考虑了低置信度的检测框。为什么呢?

因为低置信度的检测框有时表示物体的存在,例如被遮挡的物体。过滤这些对象会在MOT中引入不可逆的错误,导致不可忽视的漏检和碎片化的轨迹。让我们通过例子来理解:

Detection boxes

 如图 t1 中所示,我们初始化三个不同的tracklet,因为它们的分数高于0.5。但在 t2 和 t3 中,分数从 0.8 下降到 0.4,然后再下降到 0.1。

Tracklets by associating high scores detection boxes

 这些检测框将通过阈值机制被消除,红色轨迹随之消失,如图 b 所示。但如果我们考虑所有的检测框,将引入更多的假阳性,例如图 a 中最右侧的框。这带来了第二个问题:

  • 假阳性边框的考虑:在这里识别到与tracklets相似性提供了在低分检测框中区分对象和背景的强关联。

Tracklets by associating every detection boxes


例如,如图 c 所示,通过运动预测的框(虚线)将两个低分检测框与tracklets匹配,从而正确恢复了对象。由于背景框没有匹配的tracklet,因此将其移除。因此,为了在匹配过程中使用高分到低分的检测框,这种简单而有效的关联方法被称为BYTE,因为每个检测框是tracklet的基本单元。


首先,它根据运动或外观相似性将高分检测框与tracklets匹配。然后,它采用卡尔曼滤波器来预测tracklets在下一帧的位置。然后,可以使用IoU或Re-ID特征距离计算预测框与检测框之间的相似性。在第二个匹配步骤中,使用相同的运动相似性匹配低分检测和未匹配的tracklets,即红框中的tracklets。让我们尝试理解数据关联,这是MOT算法的核心。


数据关联


这是多目标跟踪的核心,首先计算tracklets和检测框之间的相似性,并根据相似性应用不同的策略进行匹配。

  • 相似性度量:对于关联,位置、运动和外观是三个重要的线索。SORT以非常简单的方式使用位置和运动线索。它采用卡尔曼滤波器来预测下一帧中的tracklets,然后计算检测框和预测框之间的IoU作为相似性。但是位置和运动线索适用于短程匹配。但对于长程匹配,外观相似性是有帮助的。例如,长时间被遮挡的对象将使用外观相似性进行识别。外观相似性通过Re-ID特征的余弦相似度来计算。DeepSort使用一个独立的深度学习模型进行外观相似性。


  • 匹配策略:在相似性计算后,匹配策略用于为对象分配ID。这可以通过匈牙利算法或贪婪分配来完成。SORT通过一次匹配将检测框与tracklets匹配。而DeepSort使用级联匹配策略,首先将检测框与最近的trackers匹配,然后匹配失去的tracklets。


BYTE算法



(d)BYTETrack的伪代码。(绿色是该方法的关键)

BYTE算法的输入是视频序列和检测器。还有一个检测阈值。该算法输出视频的轨迹T,每一帧包含对象的边界框和ID。对于视频中的每一帧,首先使用检测器Det预测检测框和预测分数。然后根据检测分数阈值将检测框分为Det(high)和Det(low)两类。


在分离了检测框之后,对每个轨迹T应用卡尔曼滤波器来预测当前帧的新位置。首先,在高检测框上应用关联,然后在剩余的低检测框上应用关联。BYTE的主要亮点是,它非常灵活,可以与不同的关联方法兼容。


性能


ByteTrack优于SORT和DeepSORT算法。ByteTrack的MOTA(多目标跟踪准确性)为76.6,而SORT和DeepSort分别为74.6和75.4。现在,你可能已经理解了ByteTrack的主要概念。我想这很简单。让我们尝试在实际项目中应用它。


使用YOLOv8检测器的ByteTrack


在这里,我们将看到如何使用YOLOv8检测器跟踪道路上的车辆,并计算进出的车辆数。

如你所见,每辆新车都被分配了一个ID、一个类名和检测概率。使用in和out,你可以看到进出交通的计数。让我们看看这个实现的代码:

import supervision as svfrom ultralytics import YOLO from tqdm import tqdmimport argparseimport numpy as np
tracker = sv.ByteTrack() def process_video( source_weights_path: str, source_video_path: str, target_video_path: str, confidence_threshold: float = 0.3, iou_threshold: float = 0.7) -> None: model = YOLO(source_weights_path) # Load YOLO model classes = list(model.names.values()) # Class names LINE_STARTS = sv.Point(0,500) # Line start point for count in/out vehicle LINE_END = sv.Point(1280, 500) # Line end point for count in/out vehicle tracker = sv.ByteTrack() # Bytetracker instance box_annotator = sv.BoundingBoxAnnotator() # BondingBox annotator instance label_annotator = sv.LabelAnnotator() # Label annotator instance frame_generator = sv.get_video_frames_generator(source_path=source_video_path) # for generating frames from video video_info = sv.VideoInfo.from_video_path(video_path=source_video_path) line_counter = sv.LineZone(start=LINE_STARTS, end = LINE_END) line_annotator = sv.LineZoneAnnotator(thickness=2, text_thickness=2, text_scale= 0.5)
with sv.VideoSink(target_path=target_video_path, video_info=video_info) as sink: for frame in tqdm(frame_generator, total= video_info.total_frames): # Getting result from model results = model(frame, verbose=False, conf= confidence_threshold, iou = iou_threshold)[0] detections = sv.Detections.from_ultralytics(results) # Getting detections #Filtering classes for car and truck only instead of all COCO classes. detections = detections[np.where((detections.class_id==2)|(detections.class_id==7))] detections = tracker.update_with_detections(detections) # Updating detection to Bytetracker # Annotating detection boxes annotated_frame = box_annotator.annotate(scene = frame.copy(), detections= detections)
#Prepare labels labels = [] for index in range(len(detections.class_id)): # creating labels as per required. labels.append("#" + str(detections.tracker_id[index]) + " " + classes[detections.class_id[index]] + " "+ str(round(detections.confidence[index],2)) ) # Line counter in/out trigger line_counter.trigger(detections=detections) # Annotating labels annotated_label_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) # Annotating line labels line_annotate_frame = line_annotator.annotate(frame=annotated_label_frame, line_counter=line_counter) sink.write_frame(frame = line_annotate_frame)
if __name__ == "__main__": parser = argparse.ArgumentParser("video processing with YOLO and ByteTrack") parser.add_argument( "--source_weights_path", required=True, help="Path to the source weights file", type=str ) parser.add_argument( "--source_video_path", required=True, help="Path to the source video file", type = str ) parser.add_argument( "--target_video_path", required=True, help="Path to the target video file", type= str ) parser.add_argument( "--confidence_threshold", default = 0.3, help= "Confidence threshold for the model", type=float ) parser.add_argument( "--iou_threshold", default=0.7, help="Iou threshold for the model", type= float ) args = parser.parse_args() process_video( source_weights_path=args.source_weights_path, source_video_path= args.source_video_path, target_video_path=args.target_video_path, confidence_threshold=args.confidence_threshold, iou_threshold=args.iou_threshold )

在这里,我使用了YOLOv8 Ultralytics库来加载在COCO数据集上训练的YOLO模型。Supervision库用于加载ByteTrack和其他视觉任务,如标注、车辆计数等。

你只需通过传递视频作为输入运行此命令:

python sv_bytetracker_yolo.py --source_weights_path yolov8m.pt --source_video_path test_video.mp4 --target_video_path test_pred.mp4 --confidence_threshold 0.1

如果你想跟踪其他类别,可以从代码中删除类别过滤器。


应用场景

因此,我们已经完全了解了ByteTrack。它可以在各种应用和行业中使用,例如:

1. 汽车行业:用于跟踪道路上的车辆进行交通分析,例如任何车辆是否朝错误的方向行驶,四路口的交通情况等。

2. 生产行业:可以在生产线上用于计数和跟踪生产物品。

3. 购物中的客户互动:跟踪客户的移动,了解客户对哪种产品或哪个类别更感兴趣。他们持有产品的时间有多长,最终是购买还是放回货架。

4. 增强客户体验:在客户看起来困惑或寻找产品时间过长时进行识别。


总结


1. 有各种MOT模型,如SORT、DeepSort、ByteTrack等。

2. 有各种用于对象跟踪的方法/技术,包括基于特征的跟踪、模板匹配、相关性跟踪和基于深度学习的跟踪。

3. ByteTrack算法考虑了低分检测框(与高分检测框一起)进行对象跟踪。

4. 数据关联应用于每个检测。

5. 在数据关联中,生成了tracklet和检测框之间的相似性。然后根据相似性应用不同的策略进行匹配。

6. 相似性可以通过IoU或由卡尔曼滤波器的tracklet预测的Re-ID计算。

7. 对于长距离,外观相似性是有用的。

8. 在匹配策略中使用了匈牙利算法。

9.Byte首先对高分检测框应用关联,然后对低分检测框应用关联。

       
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 70
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报