jjsos_JJdetection/train_jumper_cable.py

47 lines
1.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from ultralytics import YOLO
import os
def main():
"""
训练搭电设备分割模型(支持多边形标注)
如果标注是边界框格式,使用 yolov8n.pt
如果标注是多边形格式,使用 yolov8n-seg.pt
"""
# 检查标注格式,决定使用哪种模型
label_file = "datasets/搭电设备/train/labels/订单1804634_54_7089744.txt"
use_segmentation = False
if os.path.exists(label_file):
with open(label_file, 'r') as f:
line = f.readline().strip()
parts = line.split()
# 如果坐标数量 > 5 (class_id + 4个边界框坐标),说明是多边形格式
if len(parts) > 5:
use_segmentation = True
print("检测到多边形标注格式,使用分割模型")
else:
print("检测到边界框标注格式,使用检测模型")
# 根据标注格式选择模型
if use_segmentation:
model = YOLO("yolov8n-seg.pt") # 分割模型(支持多边形)
task = "segment"
else:
model = YOLO("yolov8n.pt") # 检测模型(边界框)
task = "detect"
# 开始训练
results = model.train(
data="datasets/搭电设备/data.yaml",
epochs=100,
batch=8,
imgsz=640,
device="cpu", # 使用GPU如果是CPU则设为"cpu"
task=task # 明确指定任务类型
)
print(f"训练完成!模型保存在: {results.save_dir}")
if __name__ == '__main__':
main()