jjsos_JJdetection/verify_seg_labels.py

231 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
验证YOLO分割标注格式的工具脚本
用于检查从CVAT导出的多边形标注是否正确
"""
import os
import cv2
import numpy as np
from pathlib import Path
def verify_yolo_seg_format(label_path, img_path=None, show_visualization=False):
    """
    Validate a YOLO segmentation label file (e.g. polygons exported from CVAT).

    Every non-empty line must have the form ``class_id x1 y1 x2 y2 ... xn yn``:
    a non-negative integer class id followed by at least 3 (x, y) pairs.
    Blank lines are ignored. Coordinates outside [0, 1] (including NaN/inf)
    are reported as warnings, not errors.

    Args:
        label_path: path to the label file.
        img_path: optional path to the matching image (used for visualization).
        show_visualization: when True and the file is valid and ``img_path``
            exists, draw the polygons on the image.

    Returns:
        dict with keys:
            valid: True when at least one object line was parsed and no line
                produced an error.
            class_id, num_points, coords_count: values of the first
                well-formed object line (None / 0 / 0 if there is none).
            num_objects: number of well-formed object lines.
            errors, warnings: lists of messages; per-line messages are
                prefixed with the 1-based line number.
    """
    result = {
        "valid": False,
        "class_id": None,
        "num_points": 0,
        "coords_count": 0,
        "num_objects": 0,
        "errors": [],
        "warnings": [],
    }

    if not os.path.exists(label_path):
        result["errors"].append(f"文件不存在: {label_path}")
        return result

    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except (OSError, UnicodeError) as e:
        result["errors"].append(f"读取文件失败: {e}")
        return result

    # Keep the 1-based line number so messages point at the offending line.
    object_lines = [(no, text.strip()) for no, text in enumerate(lines, start=1) if text.strip()]
    if not object_lines:
        result["warnings"].append("文件为空")
        return result

    for line_no, line in object_lines:
        prefix = f"第{line_no}行: "
        parts = line.split()

        # class_id + at least 3 points (6 values)
        if len(parts) < 7:
            result["errors"].append(
                prefix + f"格式错误: 坐标点数量不足 (需要至少3个点，当前只有{len(parts) - 1}个值)"
            )
            continue

        try:
            class_id = int(parts[0])
            coords = [float(v) for v in parts[1:]]
        except ValueError as e:
            result["errors"].append(prefix + f"格式错误: 无法解析数字 - {e}")
            continue

        if class_id < 0:
            result["errors"].append(prefix + f"格式错误: 类别ID不能为负数 ({class_id})")
            continue

        if len(coords) % 2 != 0:
            result["errors"].append(prefix + f"格式错误: 坐标数量不是偶数 ({len(coords)}个值)")
            continue

        # `not 0 <= c <= 1` is also True for NaN, which `c < 0 or c > 1` misses.
        out_of_range = sum(1 for c in coords if not 0.0 <= c <= 1.0)
        if out_of_range:
            result["warnings"].append(prefix + f"坐标超出[0,1]范围: {out_of_range}个值")

        # Report the first object's details in the scalar fields, as before.
        if result["num_objects"] == 0:
            result["class_id"] = class_id
            result["num_points"] = len(coords) // 2
            result["coords_count"] = len(coords)
        result["num_objects"] += 1

    result["valid"] = not result["errors"]

    if result["valid"] and show_visualization and img_path and os.path.exists(img_path):
        visualize_seg_label(img_path, label_path)

    return result
def visualize_seg_label(img_path, label_path):
    """
    Draw the polygons of a YOLO segmentation label file on its image and show it.

    Each line is parsed as ``class_id x1 y1 x2 y2 ...`` with coordinates
    normalised to [0, 1]; they are scaled by the image width/height to get
    pixel positions. Each polygon is filled semi-transparently, outlined, and
    labelled with its class id. Lines that are too short, contain non-numeric
    tokens, have an odd number of coordinates, or contain NaN/inf are skipped
    so one bad line does not abort the whole visualization.

    Args:
        img_path: path to the image file.
        label_path: path to the label file.
    """
    img = cv2.imread(img_path)
    if img is None:
        print(f"无法读取图片: {img_path}")
        return
    h, w = img.shape[:2]

    with open(label_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    for line in lines:
        parts = line.strip().split()
        if len(parts) < 7:
            continue

        try:
            class_id = int(parts[0])
            coords = [float(v) for v in parts[1:]]
            if len(coords) % 2 != 0:
                # A dangling x without its y would otherwise raise IndexError.
                continue
            # int() raises ValueError on NaN and OverflowError on inf.
            pixel_points = [
                [int(coords[i] * w), int(coords[i + 1] * h)]
                for i in range(0, len(coords), 2)
            ]
        except (ValueError, OverflowError):
            continue

        pts = np.array(pixel_points, np.int32)

        # Semi-transparent fill: draw on a copy, then blend back into img.
        overlay = img.copy()
        cv2.fillPoly(overlay, [pts], (0, 255, 0))
        cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)

        # Closed polygon outline.
        cv2.polylines(img, [pts], True, (0, 255, 0), 2)

        # putText anchors at the text's bottom-left corner; clamp so a vertex
        # near the top edge does not push the label outside the image.
        text_x, text_y = pixel_points[0]
        cv2.putText(img, f"Class {class_id}",
                    (text_x, max(text_y - 10, 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    cv2.imshow('分割标注验证', img)
    print("按任意键关闭窗口...")
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def batch_verify_labels(labels_dir, images_dir=None, sample_count=5, show_visualization=False):
    """
    Validate the YOLO segmentation label files in a directory and print a report.

    Args:
        labels_dir: directory containing ``*.txt`` label files.
        images_dir: optional directory with the matching images (same file
            stem, extension .jpg/.jpeg/.png in lower or upper case).
        sample_count: number of files to validate, chosen at random;
            0 validates all of them.
        show_visualization: when True, draw each valid label file on its
            matching image (needs ``images_dir``). Defaults to False.
    """
    # Path.glob yields entries in filesystem order; sort for a stable report.
    label_files = sorted(Path(labels_dir).glob("*.txt"))
    if not label_files:
        print(f"未找到标注文件: {labels_dir}")
        return

    if sample_count > 0 and len(label_files) > sample_count:
        import random
        label_files = random.sample(label_files, sample_count)
        print(f"随机采样 {sample_count} 个文件进行验证...")

    print(f"开始验证 {len(label_files)} 个标注文件...")
    print("-" * 60)

    image_dir = Path(images_dir) if images_dir else None
    valid_count = 0
    error_count = 0

    for label_file in label_files:
        # Find the image with the same stem as the label file, if any.
        img_path = None
        if image_dir is not None:
            for ext in ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'):
                candidate = image_dir / (label_file.stem + ext)
                if candidate.exists():
                    img_path = candidate
                    break

        result = verify_yolo_seg_format(
            str(label_file),
            str(img_path) if img_path else None,
            show_visualization=show_visualization,
        )

        # ASCII markers so the report stays readable on any console encoding.
        status = "[OK]" if result["valid"] else "[FAIL]"
        print(f"{status} {label_file.name}")
        print(f" 类别ID: {result['class_id']}, 顶点数: {result['num_points']}, 坐标数: {result['coords_count']}")

        if result["errors"]:
            print(f" 错误: {', '.join(result['errors'])}")
            error_count += 1
        # Show warnings even when there are errors, so nothing is hidden.
        if result["warnings"]:
            print(f" 警告: {', '.join(result['warnings'])}")

        if result["valid"]:
            valid_count += 1
        print()

    print("-" * 60)
    print(f"验证完成: {valid_count}/{len(label_files)} 个文件格式正确")
    if error_count > 0:
        print(f"发现 {error_count} 个文件有错误")
def main():
    """
    Command-line entry point.

    Usage: ``python verify_seg_labels.py <labels_dir> [images_dir]``.

    With arguments, validate the given label directory (optionally with its
    image directory). Without arguments, validate a 10-file random sample of
    the default 搭电设备 training split, resolved relative to the current
    working directory. If that directory is missing, print usage help instead.
    """
    import sys

    cli_args = sys.argv[1:]
    if cli_args:
        # Command-line mode: first arg = label dir, optional second = image dir.
        image_arg = cli_args[1] if len(cli_args) > 1 else None
        batch_verify_labels(cli_args[0], image_arg)
        return

    # No arguments: fall back to the default dataset layout.
    print("验证搭电设备数据集的标注格式...")
    default_labels = "datasets/搭电设备/train/labels"
    default_images = "datasets/搭电设备/train/images"

    if not os.path.exists(default_labels):
        print(f"目录不存在: {default_labels}")
        print("\n使用方法:")
        print(" python verify_seg_labels.py <labels_dir> [images_dir]")
        print("\n示例:")
        print(" python verify_seg_labels.py datasets/搭电设备/train/labels datasets/搭电设备/train/images")
        return

    batch_verify_labels(default_labels, default_images, sample_count=10)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()