123 lines
3.2 KiB
Python
123 lines
3.2 KiB
Python
"""
|
||
通用训练脚本 - 可以快速训练不同的数据集
|
||
使用方法:python train_all_datasets.py 数据集名称
|
||
示例:python train_all_datasets.py 仪表盘
|
||
"""
|
||
import sys
|
||
import os
|
||
from train_segmentation import train_segmentation_model
|
||
|
||
|
||
# 数据集配置字典
|
||
DATASET_CONFIGS = {
|
||
"搭电设备": {
|
||
"path": "datasets/搭电设备/data.yaml",
|
||
"model": "yolov8n-seg.pt",
|
||
"epochs": 100,
|
||
"batch": 8,
|
||
},
|
||
"仪表盘": {
|
||
"path": "datasets/仪表盘/data.yaml",
|
||
"model": "yolov8n-seg.pt",
|
||
"epochs": 100,
|
||
"batch": 8,
|
||
},
|
||
"检测仪": {
|
||
"path": "datasets/检测仪/data.yaml",
|
||
"model": "yolov8n-seg.pt",
|
||
"epochs": 100,
|
||
"batch": 8,
|
||
},
|
||
"猫狗": {
|
||
"path": "datasets/猫狗/data.yaml",
|
||
"model": "yolov8n-seg.pt",
|
||
"epochs": 100,
|
||
"batch": 8,
|
||
},
|
||
"搭电测试": {
|
||
"path": "datasets/搭电测试/data.yaml",
|
||
"model": "yolov8n-seg.pt",
|
||
"epochs": 50,
|
||
"batch": 8,
|
||
},
|
||
}
|
||
|
||
|
||
def list_available_datasets():
|
||
"""列出所有可用的数据集"""
|
||
print("可用的数据集:")
|
||
print("-" * 50)
|
||
for name, config in DATASET_CONFIGS.items():
|
||
exists = "✓" if os.path.exists(config["path"]) else "✗"
|
||
print(f"{exists} {name:15} - {config['path']}")
|
||
print("-" * 50)
|
||
|
||
|
||
def train_dataset(dataset_name, custom_config=None):
|
||
"""
|
||
训练指定的数据集
|
||
|
||
Args:
|
||
dataset_name: 数据集名称
|
||
custom_config: 自定义配置字典(可选)
|
||
"""
|
||
if dataset_name not in DATASET_CONFIGS:
|
||
print(f"错误: 未找到数据集 '{dataset_name}'")
|
||
print("\n可用的数据集:")
|
||
for name in DATASET_CONFIGS.keys():
|
||
print(f" - {name}")
|
||
return
|
||
|
||
config = DATASET_CONFIGS[dataset_name].copy()
|
||
|
||
# 如果提供了自定义配置,合并
|
||
if custom_config:
|
||
config.update(custom_config)
|
||
|
||
# 检查数据集文件是否存在
|
||
if not os.path.exists(config["path"]):
|
||
print(f"错误: 数据集配置文件不存在: {config['path']}")
|
||
return
|
||
|
||
print(f"开始训练数据集: {dataset_name}")
|
||
print(f"配置文件: {config['path']}")
|
||
print("-" * 50)
|
||
|
||
# 开始训练
|
||
train_segmentation_model(
|
||
dataset_path=config["path"],
|
||
model_name=config["model"],
|
||
epochs=config["epochs"],
|
||
batch=config["batch"],
|
||
imgsz=640,
|
||
device="cpu" # 可以根据需要修改为 "cuda"
|
||
)
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
if len(sys.argv) < 2:
|
||
print("使用方法: python train_all_datasets.py <数据集名称>")
|
||
print("\n示例:")
|
||
print(" python train_all_datasets.py 搭电设备")
|
||
print(" python train_all_datasets.py 仪表盘")
|
||
print(" python train_all_datasets.py 检测仪")
|
||
print("\n" + "=" * 50)
|
||
list_available_datasets()
|
||
return
|
||
|
||
dataset_name = sys.argv[1]
|
||
|
||
# 检查是否是特殊命令
|
||
if dataset_name == "list":
|
||
list_available_datasets()
|
||
return
|
||
|
||
# 开始训练
|
||
train_dataset(dataset_name)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|
||
|