231 lines
7.0 KiB
Python
231 lines
7.0 KiB
Python
"""
|
||
Utility script for validating YOLO segmentation label files.

Checks that polygon annotations exported from CVAT are well-formed.
|
||
"""
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
from pathlib import Path
|
||
|
||
|
||
def _parse_seg_line(text):
    """Parse one non-empty YOLO segmentation label line.

    Args:
        text: A stripped label line of the form ``class_id x1 y1 x2 y2 ...``.

    Returns:
        tuple: ``(class_id, coords, error)``. On success ``error`` is None.
        On failure ``class_id`` and ``coords`` are None and ``error`` holds
        the message to report.
    """
    parts = text.split()
    # A polygon needs the class id plus at least 3 points (6 values).
    if len(parts) < 7:
        return None, None, (
            f"格式错误: 坐标点数量不足 (需要至少3个点,当前只有{len(parts)-1}个值)"
        )

    try:
        class_id = int(parts[0])
        coords = [float(value) for value in parts[1:]]
    except ValueError as e:
        return None, None, f"格式错误: 无法解析数字 - {e}"

    # Coordinates come in (x, y) pairs, so their count must be even.
    if len(coords) % 2 != 0:
        return None, None, f"格式错误: 坐标数量不是偶数 ({len(coords)}个值)"

    return class_id, coords, None


def verify_yolo_seg_format(label_path, img_path=None, show_visualization=False):
    """Validate a YOLO segmentation label file.

    Every non-empty line is checked, so a file with a well-formed first
    object but malformed later objects is reported as invalid. The
    ``class_id``, ``num_points`` and ``coords_count`` fields describe the
    first object in the file.

    Args:
        label_path: Path to the label file.
        img_path: Path to the matching image (optional, only used for
            visualization).
        show_visualization: Whether to display the polygons on the image
            when the file is valid and the image exists.

    Returns:
        dict: Validation result with the keys ``valid``, ``class_id``,
        ``num_points``, ``coords_count``, ``errors`` and ``warnings``.
    """
    result = {
        "valid": False,
        "class_id": None,
        "num_points": 0,
        "coords_count": 0,
        "errors": [],
        "warnings": [],
    }

    if not os.path.exists(label_path):
        result["errors"].append(f"文件不存在: {label_path}")
        return result

    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except (OSError, ValueError) as e:
        # ValueError covers UnicodeDecodeError for files that are not UTF-8.
        result["errors"].append(f"读取文件失败: {e}")
        return result

    if not lines:
        result["warnings"].append("文件为空")
        return result

    first_line = lines[0].strip()
    if not first_line:
        result["warnings"].append("第一行为空")
        return result

    class_id, coords, error = _parse_seg_line(first_line)
    if error is not None:
        result["errors"].append(error)
        return result

    # `not 0 <= c <= 1` (instead of `c < 0 or c > 1`) also flags NaN, whose
    # comparisons are always False.
    out_of_range = sum(1 for c in coords if not 0.0 <= c <= 1.0)

    # Validate the remaining objects; blank lines are ignored.
    for line_no, raw_line in enumerate(lines[1:], start=2):
        text = raw_line.strip()
        if not text:
            continue
        _, extra_coords, error = _parse_seg_line(text)
        if error is not None:
            result["errors"].append(f"第{line_no}行: {error}")
            continue
        out_of_range += sum(1 for c in extra_coords if not 0.0 <= c <= 1.0)

    if out_of_range:
        result["warnings"].append(f"坐标超出[0,1]范围: {out_of_range}个值")

    result["class_id"] = class_id
    result["num_points"] = len(coords) // 2
    result["coords_count"] = len(coords)
    result["valid"] = not result["errors"]

    # Only draw files that passed validation; the drawing code expects
    # well-formed lines.
    if (result["valid"] and show_visualization and img_path
            and os.path.exists(img_path)):
        visualize_seg_label(img_path, label_path)

    return result
|
||
|
||
|
||
def _parse_pixel_polygon(parts, width, height):
    """Convert split label tokens into ``(class_id, pixel_points)``.

    Returns None when the tokens are malformed: a non-numeric value, an odd
    number of coordinates, or a coordinate that cannot be converted to an
    integer pixel position (NaN or infinity).
    """
    try:
        class_id = int(parts[0])
        values = [float(token) for token in parts[1:]]
        if len(values) % 2 != 0:
            return None
        # Normalized (x, y) pairs are scaled by the image width and height.
        pixel_points = [
            [int(x * width), int(y * height)]
            for x, y in zip(values[0::2], values[1::2])
        ]
    except (ValueError, OverflowError):
        return None
    return class_id, pixel_points


def visualize_seg_label(img_path, label_path):
    """Draw the polygons of a YOLO segmentation label file on its image.

    Each label line with at least 7 tokens is drawn as a semi-transparent
    green polygon with an outline and a ``Class <id>`` caption. Malformed
    lines are skipped with a message instead of aborting the whole
    visualization. The result is shown in a window until a key is pressed.

    Args:
        img_path: Path to the image file.
        label_path: Path to the label file.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"无法读取图片: {img_path}")
        return

    h, w = img.shape[:2]

    with open(label_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    for line_idx, line in enumerate(lines):
        parts = line.strip().split()
        # Fewer than class id + 3 points: not a polygon, nothing to draw.
        if len(parts) < 7:
            continue

        parsed = _parse_pixel_polygon(parts, w, h)
        if parsed is None:
            print(f"跳过格式错误的标注行: 第{line_idx + 1}行")
            continue
        class_id, pixel_points = parsed

        pts = np.array(pixel_points, np.int32)

        # Fill on a copy and blend it back to get a semi-transparent fill.
        overlay = img.copy()
        cv2.fillPoly(overlay, [pts], (0, 255, 0))
        cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)

        # Closed polygon outline.
        cv2.polylines(img, [pts], True, (0, 255, 0), 2)

        # Caption slightly above the first vertex; at least 3 vertices are
        # guaranteed here by the token-count check above.
        cv2.putText(img, f"Class {class_id}",
                    (pixel_points[0][0], pixel_points[0][1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    cv2.imshow('分割标注验证', img)
    print("按任意键关闭窗口...")
    cv2.waitKey(0)
    cv2.destroyAllWindows()
|
||
|
||
|
||
def batch_verify_labels(labels_dir, images_dir=None, sample_count=5):
    """Verify the ``*.txt`` label files of a directory and print a report.

    For each label file, the per-file result of ``verify_yolo_seg_format``
    is printed, followed by a summary of how many files were valid and how
    many had errors. Warnings are only printed for files without errors.

    Args:
        labels_dir: Directory containing the label files.
        images_dir: Directory containing the images (optional). The image
            matching each label stem is looked up with common JPEG/PNG
            extensions and passed along to the verifier.
        sample_count: Number of files to verify, chosen at random when the
            directory holds more than that (0 means all files).
    """
    pending = list(Path(labels_dir).glob("*.txt"))
    if not pending:
        print(f"未找到标注文件: {labels_dir}")
        return

    if 0 < sample_count < len(pending):
        import random
        pending = random.sample(pending, sample_count)
        print(f"随机采样 {sample_count} 个文件进行验证...")

    separator = "-" * 60
    print(f"开始验证 {len(pending)} 个标注文件...")
    print(separator)

    image_suffixes = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
    passed = 0
    failed = 0

    for label in pending:
        # First existing image sharing the label's stem, or None.
        matched_image = None
        if images_dir:
            image_root = Path(images_dir)
            matched_image = next(
                (
                    candidate
                    for candidate in (
                        image_root / (label.stem + suffix)
                        for suffix in image_suffixes
                    )
                    if candidate.exists()
                ),
                None,
            )

        image_arg = str(matched_image) if matched_image else None
        report = verify_yolo_seg_format(str(label), image_arg)

        mark = "✓" if report["valid"] else "✗"
        print(f"{mark} {label.name}")
        print(f" 类别ID: {report['class_id']}, 顶点数: {report['num_points']}, 坐标数: {report['coords_count']}")

        problems = report["errors"]
        notices = report["warnings"]
        if problems:
            print(f" 错误: {', '.join(problems)}")
            failed += 1
        elif notices:
            print(f" 警告: {', '.join(notices)}")

        passed += 1 if report["valid"] else 0

        print()

    print(separator)
    print(f"验证完成: {passed}/{len(pending)} 个文件格式正确")
    if failed > 0:
        print(f"发现 {failed} 个文件有错误")
|
||
|
||
|
||
def main():
    """Entry point: verify labels given on the command line or the default dataset.

    With arguments, ``argv[1]`` is the label directory and the optional
    ``argv[2]`` is the image directory. Without arguments, the built-in
    dataset directories are verified with a sample of 10 files; if the
    label directory does not exist, usage help is printed instead.
    """
    import sys

    cli_args = sys.argv[1:]
    if cli_args:
        # Command-line mode.
        labels_dir = cli_args[0]
        images_dir = cli_args[1] if len(cli_args) > 1 else None
        batch_verify_labels(labels_dir, images_dir)
        return

    # Interactive mode: fall back to the bundled dataset location.
    print("验证搭电设备数据集的标注格式...")
    labels_dir = "datasets/搭电设备/train/labels"
    images_dir = "datasets/搭电设备/train/images"

    if not os.path.exists(labels_dir):
        help_lines = (
            f"目录不存在: {labels_dir}",
            "\n使用方法:",
            " python verify_seg_labels.py <labels_dir> [images_dir]",
            "\n示例:",
            " python verify_seg_labels.py datasets/搭电设备/train/labels datasets/搭电设备/train/images",
        )
        for help_line in help_lines:
            print(help_line)
        return

    batch_verify_labels(labels_dir, images_dir, sample_count=10)
|
||
|
||
|
||
# Run the verifier only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||
|