jjsos_JJdetection/train_segmentation.py

99 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from ultralytics import YOLO
import os
def train_segmentation_model(
dataset_path="datasets/xianshiping/data.yaml",
model_name="yolov8n-seg.pt",
epochs=100,
batch=8,
imgsz=640,
device="cpu"
):
"""
训练YOLO分割模型实例分割支持多边形标注
Args:
dataset_path: 数据集配置文件路径
model_name: 预训练模型名称
可选: "yolov8n-seg.pt", "yolov8s-seg.pt", "yolov8m-seg.pt",
"yolov8l-seg.pt", "yolov8x-seg.pt"
"yolo11n-seg.pt", "yolo11s-seg.pt"
epochs: 训练轮数
batch: 批次大小
imgsz: 图片尺寸
device: 设备,"cpu""cuda""0" (GPU编号)
"""
# 检查数据集配置文件是否存在
if not os.path.exists(dataset_path):
print(f"错误: 数据集配置文件不存在: {dataset_path}")
return
# 检查模型文件是否存在,如果不存在会自动下载
print(f"加载模型: {model_name}")
model = YOLO(model_name)
print(f"开始训练分割模型...")
print(f"数据集: {dataset_path}")
print(f"训练轮数: {epochs}")
print(f"批次大小: {batch}")
print(f"图片尺寸: {imgsz}")
print(f"设备: {device}")
# 自动检测并使用最佳设备
import torch
if device == "cpu":
# 检查是否有Apple Silicon GPU (MPS)
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = "mps"
print("✓ 检测到Apple Silicon GPU使用MPS加速训练")
# 检查是否有CUDA GPU
elif torch.cuda.is_available():
device = "cuda"
print("✓ 检测到CUDA GPU使用GPU加速训练")
else:
print("⚠ 使用CPU训练较慢建议使用GPU")
print("-" * 50)
# 开始训练
results = model.train(
data=dataset_path,
epochs=epochs,
batch=batch,
imgsz=imgsz,
device=device,
task="segment", # 明确指定分割任务
project="runs/segment", # 保存到segment目录
name="train", # 训练名称
)
print("-" * 50)
print("训练完成!")
print(f"模型保存在: {results.save_dir}")
return results
def main():
"""主函数 - 训练检测仪数据集"""
# 数据集信息:
# - 训练集: 978张图片
# - 验证集: 245张图片
# - 总计: 1223张图片
# - 类别: 6个搭电线、拖车上有车、仪表盘、轮胎、平安马甲、检测仪
train_segmentation_model(
dataset_path="datasets/xianshiping/data.yaml", # 数据集配置文件路径
model_name="yolov8n-seg.pt", # 使用YOLOv8 nano分割模型
epochs=150, # 1223张照片建议150轮配合早停
batch=8, # CPU/MPS建议8GPU可以16-32
imgsz=640, # 图片尺寸
device="cpu" # 自动检测并使用最佳设备MPS/GPU/CPU
)
if __name__ == '__main__':
main()