99 lines
3.1 KiB
Python
99 lines
3.1 KiB
Python
from ultralytics import YOLO
|
||
import os
|
||
|
||
|
||
def train_segmentation_model(
|
||
dataset_path="datasets/xianshiping/data.yaml",
|
||
model_name="yolov8n-seg.pt",
|
||
epochs=100,
|
||
batch=8,
|
||
imgsz=640,
|
||
device="cpu"
|
||
):
|
||
"""
|
||
训练YOLO分割模型(实例分割,支持多边形标注)
|
||
|
||
Args:
|
||
dataset_path: 数据集配置文件路径
|
||
model_name: 预训练模型名称
|
||
可选: "yolov8n-seg.pt", "yolov8s-seg.pt", "yolov8m-seg.pt",
|
||
"yolov8l-seg.pt", "yolov8x-seg.pt"
|
||
或 "yolo11n-seg.pt", "yolo11s-seg.pt" 等
|
||
epochs: 训练轮数
|
||
batch: 批次大小
|
||
imgsz: 图片尺寸
|
||
device: 设备,"cpu" 或 "cuda" 或 "0" (GPU编号)
|
||
"""
|
||
# 检查数据集配置文件是否存在
|
||
if not os.path.exists(dataset_path):
|
||
print(f"错误: 数据集配置文件不存在: {dataset_path}")
|
||
return
|
||
|
||
# 检查模型文件是否存在,如果不存在会自动下载
|
||
print(f"加载模型: {model_name}")
|
||
model = YOLO(model_name)
|
||
|
||
print(f"开始训练分割模型...")
|
||
print(f"数据集: {dataset_path}")
|
||
print(f"训练轮数: {epochs}")
|
||
print(f"批次大小: {batch}")
|
||
print(f"图片尺寸: {imgsz}")
|
||
print(f"设备: {device}")
|
||
|
||
# 自动检测并使用最佳设备
|
||
import torch
|
||
if device == "cpu":
|
||
# 检查是否有Apple Silicon GPU (MPS)
|
||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||
device = "mps"
|
||
print("✓ 检测到Apple Silicon GPU,使用MPS加速训练")
|
||
# 检查是否有CUDA GPU
|
||
elif torch.cuda.is_available():
|
||
device = "cuda"
|
||
print("✓ 检测到CUDA GPU,使用GPU加速训练")
|
||
else:
|
||
print("⚠ 使用CPU训练(较慢,建议使用GPU)")
|
||
|
||
print("-" * 50)
|
||
|
||
# 开始训练
|
||
results = model.train(
|
||
data=dataset_path,
|
||
epochs=epochs,
|
||
batch=batch,
|
||
imgsz=imgsz,
|
||
device=device,
|
||
task="segment", # 明确指定分割任务
|
||
project="runs/segment", # 保存到segment目录
|
||
name="train", # 训练名称
|
||
)
|
||
|
||
print("-" * 50)
|
||
print("训练完成!")
|
||
print(f"模型保存在: {results.save_dir}")
|
||
|
||
return results
|
||
|
||
|
||
def main():
|
||
"""主函数 - 训练检测仪数据集"""
|
||
# 数据集信息:
|
||
# - 训练集: 978张图片
|
||
# - 验证集: 245张图片
|
||
# - 总计: 1223张图片
|
||
# - 类别: 6个(搭电线、拖车上有车、仪表盘、轮胎、平安马甲、检测仪)
|
||
|
||
train_segmentation_model(
|
||
dataset_path="datasets/xianshiping/data.yaml", # 数据集配置文件路径
|
||
model_name="yolov8n-seg.pt", # 使用YOLOv8 nano分割模型
|
||
epochs=150, # 1223张照片,建议150轮(配合早停)
|
||
batch=8, # CPU/MPS建议8,GPU可以16-32
|
||
imgsz=640, # 图片尺寸
|
||
device="cpu" # 自动检测并使用最佳设备(MPS/GPU/CPU)
|
||
)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|
||
|