47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
from ultralytics import YOLO
|
||
import os
|
||
|
||
|
||
def main():
|
||
"""
|
||
训练搭电设备分割模型(支持多边形标注)
|
||
如果标注是边界框格式,使用 yolov8n.pt
|
||
如果标注是多边形格式,使用 yolov8n-seg.pt
|
||
"""
|
||
# 检查标注格式,决定使用哪种模型
|
||
label_file = "datasets/搭电设备/train/labels/订单1804634_54_7089744.txt"
|
||
use_segmentation = False
|
||
|
||
if os.path.exists(label_file):
|
||
with open(label_file, 'r') as f:
|
||
line = f.readline().strip()
|
||
parts = line.split()
|
||
# 如果坐标数量 > 5 (class_id + 4个边界框坐标),说明是多边形格式
|
||
if len(parts) > 5:
|
||
use_segmentation = True
|
||
print("检测到多边形标注格式,使用分割模型")
|
||
else:
|
||
print("检测到边界框标注格式,使用检测模型")
|
||
|
||
# 根据标注格式选择模型
|
||
if use_segmentation:
|
||
model = YOLO("yolov8n-seg.pt") # 分割模型(支持多边形)
|
||
task = "segment"
|
||
else:
|
||
model = YOLO("yolov8n.pt") # 检测模型(边界框)
|
||
task = "detect"
|
||
|
||
# 开始训练
|
||
results = model.train(
|
||
data="datasets/搭电设备/data.yaml",
|
||
epochs=100,
|
||
batch=8,
|
||
imgsz=640,
|
||
device="cpu", # 使用GPU,如果是CPU则设为"cpu"
|
||
task=task # 明确指定任务类型
|
||
)
|
||
|
||
print(f"训练完成!模型保存在: {results.save_dir}")
|
||
|
||
if __name__ == '__main__':
|
||
main() |