""" 通用训练脚本 - 可以快速训练不同的数据集 使用方法:python train_all_datasets.py 数据集名称 示例:python train_all_datasets.py 仪表盘 """ import sys import os from train_segmentation import train_segmentation_model # 数据集配置字典 DATASET_CONFIGS = { "搭电设备": { "path": "datasets/搭电设备/data.yaml", "model": "yolov8n-seg.pt", "epochs": 100, "batch": 8, }, "仪表盘": { "path": "datasets/仪表盘/data.yaml", "model": "yolov8n-seg.pt", "epochs": 100, "batch": 8, }, "检测仪": { "path": "datasets/检测仪/data.yaml", "model": "yolov8n-seg.pt", "epochs": 100, "batch": 8, }, "猫狗": { "path": "datasets/猫狗/data.yaml", "model": "yolov8n-seg.pt", "epochs": 100, "batch": 8, }, "搭电测试": { "path": "datasets/搭电测试/data.yaml", "model": "yolov8n-seg.pt", "epochs": 50, "batch": 8, }, } def list_available_datasets(): """列出所有可用的数据集""" print("可用的数据集:") print("-" * 50) for name, config in DATASET_CONFIGS.items(): exists = "✓" if os.path.exists(config["path"]) else "✗" print(f"{exists} {name:15} - {config['path']}") print("-" * 50) def train_dataset(dataset_name, custom_config=None): """ 训练指定的数据集 Args: dataset_name: 数据集名称 custom_config: 自定义配置字典(可选) """ if dataset_name not in DATASET_CONFIGS: print(f"错误: 未找到数据集 '{dataset_name}'") print("\n可用的数据集:") for name in DATASET_CONFIGS.keys(): print(f" - {name}") return config = DATASET_CONFIGS[dataset_name].copy() # 如果提供了自定义配置,合并 if custom_config: config.update(custom_config) # 检查数据集文件是否存在 if not os.path.exists(config["path"]): print(f"错误: 数据集配置文件不存在: {config['path']}") return print(f"开始训练数据集: {dataset_name}") print(f"配置文件: {config['path']}") print("-" * 50) # 开始训练 train_segmentation_model( dataset_path=config["path"], model_name=config["model"], epochs=config["epochs"], batch=config["batch"], imgsz=640, device="cpu" # 可以根据需要修改为 "cuda" ) def main(): """主函数""" if len(sys.argv) < 2: print("使用方法: python train_all_datasets.py <数据集名称>") print("\n示例:") print(" python train_all_datasets.py 搭电设备") print(" python train_all_datasets.py 仪表盘") print(" python train_all_datasets.py 检测仪") print("\n" + "=" * 50) list_available_datasets() return dataset_name = sys.argv[1] # 检查是否是特殊命令 if dataset_name == "list": list_available_datasets() return # 开始训练 train_dataset(dataset_name) if __name__ == '__main__': main()