92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
import os
|
|
import xml.etree.ElementTree as ET
|
|
|
|
def _read_int(parent, tag):
    """Return the integer text of child ``tag`` under ``parent``, or None.

    None is returned when the child element is missing, has no text, or its
    text is not a valid integer literal.
    """
    child = parent.find(tag)
    if child is None or child.text is None:
        return None
    try:
        return int(child.text)
    except ValueError:
        return None


def convert_voc_to_yolo(xml_path, output_dir):
    """Convert one Pascal VOC XML annotation into a YOLO-format TXT file.

    Each ``<object>`` whose ``<name>`` is a known class and whose ``<bndbox>``
    is complete becomes one line ``class_id x_center y_center width height``,
    with the four box values normalised by the image size and printed with
    six decimals.

    Args:
        xml_path: Path of the VOC XML file to read.
        output_dir: Directory that receives ``<xml basename>.txt``.

    Returns:
        None. The TXT file is written only when at least one object was
        converted. The file is skipped entirely when ``<size>`` is missing
        or the image width/height is missing, non-integer, or not positive.

    Raises:
        xml.etree.ElementTree.ParseError: If the XML is malformed.
        OSError: If the XML cannot be read or the TXT cannot be written.
    """
    root = ET.parse(xml_path).getroot()

    # Image dimensions are the normalisation denominators.
    size = root.find('size')
    if size is None:
        return

    width = _read_int(size, 'width')
    height = _read_int(size, 'height')
    if width is None or height is None:
        return
    # Some annotation tools write 0 when the image size is unknown; dividing
    # by it would raise ZeroDivisionError, so the file is skipped instead.
    if width <= 0 or height <= 0:
        return

    # Class name -> YOLO class index.
    class_map = {'猫': 0, '狗': 1}

    yolo_lines = []
    for obj in root.findall('object'):
        name_elem = obj.find('name')
        if name_elem is None or name_elem.text is None:
            continue

        class_id = class_map.get(name_elem.text)
        if class_id is None:
            # Unknown class: not part of the training set.
            continue

        bndbox = obj.find('bndbox')
        if bndbox is None:
            continue

        corners = [_read_int(bndbox, tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        if None in corners:
            # A missing or invalid corner must not be replaced by 0: that
            # would emit a box stretched to the image origin.
            continue
        xmin, ymin, xmax, ymax = corners

        if xmax < xmin or ymax < ymin:
            # Inverted corners would produce a negative width/height.
            continue

        # Corner coordinates -> normalised centre/size.
        x_center = (xmin + xmax) / 2 / width
        y_center = (ymin + ymax) / 2 / height
        box_width = (xmax - xmin) / width
        box_height = (ymax - ymin) / height

        yolo_lines.append(
            f"{class_id} {x_center:.6f} {y_center:.6f} {box_width:.6f} {box_height:.6f}"
        )

    if not yolo_lines:
        return

    txt_filename = os.path.splitext(os.path.basename(xml_path))[0] + '.txt'
    txt_path = os.path.join(output_dir, txt_filename)
    with open(txt_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(yolo_lines))
|
|
|
|
def batch_convert(input_dir, output_dir):
    """Convert every ``.xml`` file directly inside ``input_dir``.

    ``output_dir`` is created if it does not exist. Each entry of
    ``input_dir`` whose name ends with ``.xml`` is passed to
    ``convert_voc_to_yolo`` together with ``output_dir``; subdirectories are
    not descended into.
    """
    os.makedirs(output_dir, exist_ok=True)

    xml_names = [entry for entry in os.listdir(input_dir) if entry.endswith('.xml')]
    for entry in xml_names:
        convert_voc_to_yolo(os.path.join(input_dir, entry), output_dir)
|
|
|
|
if __name__ == "__main__":
    # The YOLO TXT files are written next to the VOC XML files, so a single
    # directory serves as both input and output.
    input_dir = output_dir = r'F:\myprojects\啾啾救援识别系统\JJCarDetection\datasets\猫狗\train\labels'

    print(f"开始转换: {input_dir}")
    batch_convert(input_dir, output_dir)
    print(f"转换完成,结果保存在: {output_dir}")