from ultralytics import YOLO import os def main(): """ 训练搭电设备分割模型(支持多边形标注) 如果标注是边界框格式,使用 yolov8n.pt 如果标注是多边形格式,使用 yolov8n-seg.pt """ # 检查标注格式,决定使用哪种模型 label_file = "datasets/搭电设备/train/labels/订单1804634_54_7089744.txt" use_segmentation = False if os.path.exists(label_file): with open(label_file, 'r') as f: line = f.readline().strip() parts = line.split() # 如果坐标数量 > 5 (class_id + 4个边界框坐标),说明是多边形格式 if len(parts) > 5: use_segmentation = True print("检测到多边形标注格式,使用分割模型") else: print("检测到边界框标注格式,使用检测模型") # 根据标注格式选择模型 if use_segmentation: model = YOLO("yolov8n-seg.pt") # 分割模型(支持多边形) task = "segment" else: model = YOLO("yolov8n.pt") # 检测模型(边界框) task = "detect" # 开始训练 results = model.train( data="datasets/搭电设备/data.yaml", epochs=100, batch=8, imgsz=640, device="cpu", # 使用GPU,如果是CPU则设为"cpu" task=task # 明确指定任务类型 ) print(f"训练完成!模型保存在: {results.save_dir}") if __name__ == '__main__': main()