jjsos_JJdetection/train_all_datasets.py

123 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
通用训练脚本 - 可以快速训练不同的数据集
使用方法python train_all_datasets.py 数据集名称
示例python train_all_datasets.py 仪表盘
"""
import sys
import os
from train_segmentation import train_segmentation_model
# 数据集配置字典
DATASET_CONFIGS = {
"搭电设备": {
"path": "datasets/搭电设备/data.yaml",
"model": "yolov8n-seg.pt",
"epochs": 100,
"batch": 8,
},
"仪表盘": {
"path": "datasets/仪表盘/data.yaml",
"model": "yolov8n-seg.pt",
"epochs": 100,
"batch": 8,
},
"检测仪": {
"path": "datasets/检测仪/data.yaml",
"model": "yolov8n-seg.pt",
"epochs": 100,
"batch": 8,
},
"猫狗": {
"path": "datasets/猫狗/data.yaml",
"model": "yolov8n-seg.pt",
"epochs": 100,
"batch": 8,
},
"搭电测试": {
"path": "datasets/搭电测试/data.yaml",
"model": "yolov8n-seg.pt",
"epochs": 50,
"batch": 8,
},
}
def list_available_datasets():
"""列出所有可用的数据集"""
print("可用的数据集:")
print("-" * 50)
for name, config in DATASET_CONFIGS.items():
exists = "" if os.path.exists(config["path"]) else ""
print(f"{exists} {name:15} - {config['path']}")
print("-" * 50)
def train_dataset(dataset_name, custom_config=None):
"""
训练指定的数据集
Args:
dataset_name: 数据集名称
custom_config: 自定义配置字典(可选)
"""
if dataset_name not in DATASET_CONFIGS:
print(f"错误: 未找到数据集 '{dataset_name}'")
print("\n可用的数据集:")
for name in DATASET_CONFIGS.keys():
print(f" - {name}")
return
config = DATASET_CONFIGS[dataset_name].copy()
# 如果提供了自定义配置,合并
if custom_config:
config.update(custom_config)
# 检查数据集文件是否存在
if not os.path.exists(config["path"]):
print(f"错误: 数据集配置文件不存在: {config['path']}")
return
print(f"开始训练数据集: {dataset_name}")
print(f"配置文件: {config['path']}")
print("-" * 50)
# 开始训练
train_segmentation_model(
dataset_path=config["path"],
model_name=config["model"],
epochs=config["epochs"],
batch=config["batch"],
imgsz=640,
device="cpu" # 可以根据需要修改为 "cuda"
)
def main():
"""主函数"""
if len(sys.argv) < 2:
print("使用方法: python train_all_datasets.py <数据集名称>")
print("\n示例:")
print(" python train_all_datasets.py 搭电设备")
print(" python train_all_datasets.py 仪表盘")
print(" python train_all_datasets.py 检测仪")
print("\n" + "=" * 50)
list_available_datasets()
return
dataset_name = sys.argv[1]
# 检查是否是特殊命令
if dataset_name == "list":
list_available_datasets()
return
# 开始训练
train_dataset(dataset_name)
if __name__ == '__main__':
main()