""" 验证YOLO分割标注格式的工具脚本 用于检查从CVAT导出的多边形标注是否正确 """ import os import cv2 import numpy as np from pathlib import Path def verify_yolo_seg_format(label_path, img_path=None, show_visualization=False): """ 验证YOLO分割标注格式 Args: label_path: 标注文件路径 img_path: 对应的图片路径(可选,用于可视化) show_visualization: 是否显示可视化结果 Returns: dict: 验证结果信息 """ result = { "valid": False, "class_id": None, "num_points": 0, "coords_count": 0, "errors": [], "warnings": [] } if not os.path.exists(label_path): result["errors"].append(f"文件不存在: {label_path}") return result try: with open(label_path, 'r', encoding='utf-8') as f: lines = f.readlines() except Exception as e: result["errors"].append(f"读取文件失败: {e}") return result if not lines: result["warnings"].append("文件为空") return result # 解析第一行(通常一个对象一行) line = lines[0].strip() if not line: result["warnings"].append("第一行为空") return result parts = line.split() if len(parts) < 7: # 至少需要 class_id + 3个点(6个值) result["errors"].append(f"格式错误: 坐标点数量不足 (需要至少3个点,当前只有{len(parts)-1}个值)") return result try: class_id = int(parts[0]) coords = list(map(float, parts[1:])) except ValueError as e: result["errors"].append(f"格式错误: 无法解析数字 - {e}") return result # 检查坐标数量(必须是偶数) if len(coords) % 2 != 0: result["errors"].append(f"格式错误: 坐标数量不是偶数 ({len(coords)}个值)") return result num_points = len(coords) // 2 # 检查坐标范围 invalid_coords = [c for c in coords if c < 0 or c > 1] if invalid_coords: result["warnings"].append(f"坐标超出[0,1]范围: {len(invalid_coords)}个值") result["valid"] = True result["class_id"] = class_id result["num_points"] = num_points result["coords_count"] = len(coords) # 可视化验证 if show_visualization and img_path and os.path.exists(img_path): visualize_seg_label(img_path, label_path) return result def visualize_seg_label(img_path, label_path): """ 可视化分割标注,在图片上绘制多边形 Args: img_path: 图片路径 label_path: 标注文件路径 """ img = cv2.imread(img_path) if img is None: print(f"无法读取图片: {img_path}") return h, w = img.shape[:2] with open(label_path, 'r', encoding='utf-8') as f: lines = f.readlines() for line_idx, line in enumerate(lines): parts = line.strip().split() if len(parts) < 7: continue class_id = int(parts[0]) points = list(map(float, parts[1:])) # 转换为像素坐标 pixel_points = [] for i in range(0, len(points), 2): x = int(points[i] * w) y = int(points[i+1] * h) pixel_points.append([x, y]) # 绘制多边形 pts = np.array(pixel_points, np.int32) # 填充多边形(半透明) overlay = img.copy() cv2.fillPoly(overlay, [pts], (0, 255, 0)) cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img) # 绘制多边形边界 cv2.polylines(img, [pts], True, (0, 255, 0), 2) # 添加类别标签 if len(pixel_points) > 0: cv2.putText(img, f"Class {class_id}", (pixel_points[0][0], pixel_points[0][1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) # 显示结果 cv2.imshow('分割标注验证', img) print("按任意键关闭窗口...") cv2.waitKey(0) cv2.destroyAllWindows() def batch_verify_labels(labels_dir, images_dir=None, sample_count=5): """ 批量验证标注文件 Args: labels_dir: 标注文件目录 images_dir: 图片目录(可选) sample_count: 随机采样验证的文件数量(0表示全部) """ label_files = list(Path(labels_dir).glob("*.txt")) if not label_files: print(f"未找到标注文件: {labels_dir}") return if sample_count > 0 and len(label_files) > sample_count: import random label_files = random.sample(label_files, sample_count) print(f"随机采样 {sample_count} 个文件进行验证...") print(f"开始验证 {len(label_files)} 个标注文件...") print("-" * 60) valid_count = 0 error_count = 0 for label_file in label_files: # 尝试找到对应的图片文件 img_path = None if images_dir: for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']: img_path = Path(images_dir) / (label_file.stem + ext) if img_path.exists(): break img_path = None result = verify_yolo_seg_format(str(label_file), str(img_path) if img_path else None) status = "✓" if result["valid"] else "✗" print(f"{status} {label_file.name}") print(f" 类别ID: {result['class_id']}, 顶点数: {result['num_points']}, 坐标数: {result['coords_count']}") if result["errors"]: print(f" 错误: {', '.join(result['errors'])}") error_count += 1 elif result["warnings"]: print(f" 警告: {', '.join(result['warnings'])}") if result["valid"]: valid_count += 1 print() print("-" * 60) print(f"验证完成: {valid_count}/{len(label_files)} 个文件格式正确") if error_count > 0: print(f"发现 {error_count} 个文件有错误") def main(): """主函数 - 示例用法""" import sys if len(sys.argv) > 1: # 命令行模式 labels_dir = sys.argv[1] images_dir = sys.argv[2] if len(sys.argv) > 2 else None batch_verify_labels(labels_dir, images_dir) else: # 交互模式 - 验证搭电设备数据集 print("验证搭电设备数据集的标注格式...") labels_dir = "datasets/搭电设备/train/labels" images_dir = "datasets/搭电设备/train/images" if os.path.exists(labels_dir): batch_verify_labels(labels_dir, images_dir, sample_count=10) else: print(f"目录不存在: {labels_dir}") print("\n使用方法:") print(" python verify_seg_labels.py [images_dir]") print("\n示例:") print(" python verify_seg_labels.py datasets/搭电设备/train/labels datasets/搭电设备/train/images") if __name__ == '__main__': main()