Files
uzh-fpv-sv-test/rosbag2wds.py
2026-05-29 18:49:01 +08:00

396 lines
16 KiB
Python

#!/usr/bin/env python3
"""
ROS bag to WebDataset converter for DAVIS dataset
Extracts: grayscale images, IMU sequence, ground truth poses and velocities
Usage:
python rosbag2wds.py --bag <path_to.bag> [--output <output_dir>] [--name <dataset_name>]
"""
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import numpy as np
import rosbag
from cv_bridge import CvBridge
import cv2
import webdataset as wds
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
class BagToWebDataset:
    """Convert one DAVIS ROS bag into a WebDataset of image/ground-truth samples.

    Pipeline (see ``convert``): read images, IMU and ground-truth odometry from
    the bag; crop images and IMU to the ground-truth time span; save the IMU
    stream to ``imu_sequence.npz``; pair each image with its nearest GT pose and
    velocity; write the pairs as ``shard_XXXX.tar`` files plus a
    ``metadata.json`` summary. Everything is written under
    ``<output_dir>/<dataset_name>``.

    All timestamps are the bag record time ``t`` yielded by
    ``rosbag.Bag.read_messages``, not ``msg.header.stamp``.
    """

    IMAGE_TOPIC = '/dvs/image_raw'
    IMU_TOPIC = '/dvs/imu'
    GT_TOPIC = '/groundtruth/odometry'
    # Topics read from the bag (this order is also used for the printed counts).
    TOPICS = (IMAGE_TOPIC, IMU_TOPIC, GT_TOPIC)

    def __init__(self, bag_path: str, output_dir: str, dataset_name: str,
                 shard_size: int = 2000, image_width: int = 320, image_height: int = 240):
        """Store conversion settings and create empty data containers.

        Args:
            bag_path: Path to the input ``.bag`` file.
            output_dir: Parent directory; output goes to ``output_dir/dataset_name``.
            dataset_name: Name of the output sub-directory and metadata entry.
            shard_size: Maximum number of samples per tar shard.
            image_width: Target width for ``cv2.resize`` (resize skipped if falsy).
            image_height: Target height for ``cv2.resize`` (resize skipped if falsy).
        """
        self.bag_path = Path(bag_path)
        self.output_dir = Path(output_dir) / dataset_name
        self.dataset_name = dataset_name
        self.shard_size = shard_size
        self.image_width = image_width
        self.image_height = image_height
        self.bridge = CvBridge()
        # Data containers, filled by the _process_* handlers.
        self.images: List[Tuple[float, np.ndarray]] = []  # (timestamp, image)
        self.imu_timestamps: List[float] = []
        self.imu_acc: List[np.ndarray] = []  # (ax, ay, az)
        self.imu_gyro: List[np.ndarray] = []  # (gx, gy, gz)
        self.gt_timestamps: List[float] = []
        self.gt_poses: List[np.ndarray] = []  # (x, y, z, qx, qy, qz, qw)
        # (vx, vy, vz, wx, wy, wz); None until _ensure_velocities fills it in.
        self.gt_velocities: List[Optional[np.ndarray]] = []

    def extract_all_data(self):
        """Read all relevant messages from the bag into the in-memory containers.

        Each message is dispatched on its topic to the matching ``_process_*``
        handler. Afterwards GT velocities are derived from the poses and a
        summary is printed.
        """
        print(f"Opening bag: {self.bag_path}")
        topics = list(self.TOPICS)
        handlers = {
            self.IMAGE_TOPIC: self._process_image,
            self.IMU_TOPIC: self._process_imu,
            self.GT_TOPIC: self._process_odometry,
        }
        # `with` closes the bag even if a handler raises mid-read.
        with rosbag.Bag(str(self.bag_path), 'r') as bag:
            # Count messages up front so the progress bar has a total.
            topic_counts = {topic: bag.get_message_count(topic) for topic in topics}
            total_msgs = sum(topic_counts.values())
            print(f"Topics: {topic_counts}")
            with tqdm(total=total_msgs, desc="Extracting messages") as pbar:
                for topic, msg, t in bag.read_messages(topics=topics):
                    handlers[topic](msg, t)
                    pbar.update(1)
        # The bag's twist data is not used; velocities are computed from poses.
        self._ensure_velocities()
        print("\nExtraction completed:")
        print(f" Images: {len(self.images)}")
        print(f" IMU messages: {len(self.imu_timestamps)}")
        print(f" Ground truth poses: {len(self.gt_timestamps)}")
        print(f" Ground truth velocities: {len(self.gt_velocities)}")

    def crop_to_gt_time_range(self):
        """Keep only images and IMU samples inside [first GT time, last GT time].

        GT data defines the range and is left unchanged. If there is no GT data,
        a warning is printed and nothing is cropped.
        """
        if not self.gt_timestamps:
            print("Warning: No GT data found, skipping crop")
            return
        gt_start = min(self.gt_timestamps)
        gt_end = max(self.gt_timestamps)
        print(f"\nCropping to GT time range: {gt_start:.3f} - {gt_end:.3f} ({gt_end - gt_start:.1f}s)")

        # Crop images.
        original_img_count = len(self.images)
        self.images = [(ts, img) for ts, img in self.images if gt_start <= ts <= gt_end]
        print(f" Images: {original_img_count} -> {len(self.images)}")

        # Crop IMU. Always apply the filter, even when nothing survives, so that
        # IMU data never extends outside the GT span.
        original_imu_count = len(self.imu_timestamps)
        keep = [i for i, ts in enumerate(self.imu_timestamps) if gt_start <= ts <= gt_end]
        self.imu_timestamps = [self.imu_timestamps[i] for i in keep]
        self.imu_acc = [self.imu_acc[i] for i in keep]
        self.imu_gyro = [self.imu_gyro[i] for i in keep]
        print(f" IMU: {original_imu_count} -> {len(self.imu_timestamps)}")

        # GT data lies inside its own range by construction; nothing to crop.
        print(f" GT: {len(self.gt_timestamps)} (unchanged)")

    def _process_image(self, msg, t):
        """Convert an image message to mono8, optionally resize it, and store it.

        The image is stored together with the bag record time ``t``. Conversion
        failures are printed and the frame is skipped (best effort).
        """
        try:
            cv_img = self.bridge.imgmsg_to_cv2(msg, desired_encoding='mono8')
            if self.image_width and self.image_height:
                cv_img = cv2.resize(cv_img, (self.image_width, self.image_height),
                                    interpolation=cv2.INTER_LINEAR)
            # Use the bag record time instead of msg.header.stamp, matching the
            # IMU and odometry handlers so that all streams share one clock.
            timestamp = t.to_sec()
            self.images.append((timestamp, cv_img))
        except Exception as e:
            print(f"Error processing image: {e}")

    def _process_imu(self, msg, t):
        """Store the linear acceleration and angular velocity of an IMU message at bag time ``t``."""
        timestamp = t.to_sec()
        # Linear acceleration (m/s^2).
        acc = np.array([msg.linear_acceleration.x,
                        msg.linear_acceleration.y,
                        msg.linear_acceleration.z], dtype=np.float32)
        # Angular velocity (rad/s).
        gyro = np.array([msg.angular_velocity.x,
                         msg.angular_velocity.y,
                         msg.angular_velocity.z], dtype=np.float32)
        self.imu_timestamps.append(timestamp)
        self.imu_acc.append(acc)
        self.imu_gyro.append(gyro)

    def _process_odometry(self, msg, t):
        """Store the pose of a GT odometry message at bag time ``t``.

        The velocity slot is filled with ``None``; ``_ensure_velocities``
        computes it later from pose differences.
        """
        timestamp = t.to_sec()
        pos = np.array([msg.pose.pose.position.x,
                        msg.pose.pose.position.y,
                        msg.pose.pose.position.z], dtype=np.float32)
        # Quaternion in (qx, qy, qz, qw) order, the order scipy's Rotation.from_quat expects.
        quat = np.array([msg.pose.pose.orientation.x,
                         msg.pose.pose.orientation.y,
                         msg.pose.pose.orientation.z,
                         msg.pose.pose.orientation.w], dtype=np.float32)
        self.gt_timestamps.append(timestamp)
        self.gt_poses.append(np.concatenate([pos, quat]))
        self.gt_velocities.append(None)

    @staticmethod
    def _finite_difference_velocity(pose_prev, pose_next, dt) -> np.ndarray:
        """Return the float32 twist (vx, vy, vz, wx, wy, wz) taking pose_prev to pose_next in dt seconds.

        The linear part is the position difference divided by ``dt``. The
        angular part is the rotation vector of ``q_next * q_prev^-1`` divided by
        ``dt``; this delta is left-multiplied, so it is expressed in the frame
        the poses are given in. Returns zeros when ``dt <= 0``.
        """
        if dt <= 0:
            return np.zeros(6, dtype=np.float32)
        v_lin = (pose_next[:3] - pose_prev[:3]) / dt
        dq = R.from_quat(pose_next[3:7]) * R.from_quat(pose_prev[3:7]).inv()
        v_ang = dq.as_rotvec() / dt
        return np.concatenate([v_lin, v_ang]).astype(np.float32)

    def _ensure_velocities(self):
        """Fill every ``None`` GT velocity by finite differences of consecutive GT poses.

        The twist data in this dataset is zero, so velocities always come from
        pose differences. Frame 0 uses the forward difference (0 -> 1); every
        other frame i uses the backward difference (i-1 -> i). A single GT
        sample gets a zero velocity.
        """
        print("Computing velocities from pose differences...")
        ts = self.gt_timestamps
        poses = self.gt_poses
        # Velocity for each consecutive pair (k, k+1).
        pair_vels = [
            self._finite_difference_velocity(poses[k], poses[k + 1], ts[k + 1] - ts[k])
            for k in range(len(ts) - 1)
        ]
        if pair_vels:
            # Frame 0 reuses the first pair (forward difference); frame i>0 uses pair i-1.
            computed = [pair_vels[0].copy()] + pair_vels
        else:
            computed = [np.zeros(6, dtype=np.float32) for _ in ts]
        self.gt_velocities = [
            computed_vel if vel is None else vel
            for vel, computed_vel in zip(self.gt_velocities, computed)
        ]

    def save_imu_sequence(self):
        """Save the IMU stream to ``<output_dir>/imu_sequence.npz`` and return its path.

        Arrays: ``timestamps`` (N,) float64, ``accelerations`` (N, 3) float32,
        ``angular_velocities`` (N, 3) float32. The reshape keeps the (0, 3)
        shape when there is no IMU data.
        """
        imu_data = {
            'timestamps': np.array(self.imu_timestamps, dtype=np.float64),
            'accelerations': np.array(self.imu_acc, dtype=np.float32).reshape(-1, 3),
            'angular_velocities': np.array(self.imu_gyro, dtype=np.float32).reshape(-1, 3),
        }
        imu_path = self.output_dir / 'imu_sequence.npz'
        imu_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(imu_path, **imu_data)
        print(f"Saved IMU sequence: {imu_path}")
        return imu_path

    def align_ground_truth_to_images(
            self, max_time_diff: float = 0.1
    ) -> List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]:
        """Pair each image with the GT sample nearest in time.

        Args:
            max_time_diff: Images whose nearest GT sample is ``max_time_diff``
                seconds or more away are dropped (default 0.1 s).

        Returns:
            ``(image_timestamp, image, pose, velocity)`` tuples in image order.
            The list is empty when there is no GT data.
        """
        if not self.gt_timestamps:
            return []
        gt_timestamps = np.asarray(self.gt_timestamps, dtype=np.float64)
        gt_poses = np.asarray(self.gt_poses)
        gt_vels = np.asarray(self.gt_velocities)
        aligned_gt = []
        for img_ts, img in tqdm(self.images, desc="Aligning ground truth to images"):
            diffs = np.abs(gt_timestamps - img_ts)
            idx = int(np.argmin(diffs))
            if diffs[idx] < max_time_diff:
                aligned_gt.append((img_ts, img, gt_poses[idx], gt_vels[idx]))
        return aligned_gt

    @staticmethod
    def _shard_bounds(num_samples: int, shard_size: int) -> List[Tuple[int, int]]:
        """Split ``range(num_samples)`` into consecutive ``[start, end)`` chunks of at most ``shard_size``.

        Raises:
            ValueError: if ``shard_size`` is not positive.
        """
        if shard_size <= 0:
            raise ValueError(f"shard_size must be positive, got {shard_size}")
        return [(start, min(start + shard_size, num_samples))
                for start in range(0, num_samples, shard_size)]

    def save_as_webdataset(self, aligned_gt: List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]):
        """Write the aligned samples as WebDataset tar shards ``shard_XXXX.tar``.

        Each shard holds at most ``self.shard_size`` samples. Each sample has
        ``jpg`` (JPEG, quality 85), ``ts`` (float64 bytes), ``pose`` (7 x
        float32 bytes) and ``vel`` (6 x float32 bytes). Keys are
        ``frame_<global index>``, so they are unique across all shards.

        Raises:
            RuntimeError: if OpenCV fails to JPEG-encode an image.
        """
        bounds = self._shard_bounds(len(aligned_gt), self.shard_size)
        print(f"Saving {len(aligned_gt)} samples into {len(bounds)} shards...")
        self.output_dir.mkdir(parents=True, exist_ok=True)
        for shard_idx, (start_idx, end_idx) in enumerate(bounds):
            tar_path = self.output_dir / f'shard_{shard_idx:04d}.tar'
            with wds.TarWriter(str(tar_path)) as sink:
                # Only this shard's slice; the global index keeps keys unique across shards.
                for global_idx in range(start_idx, end_idx):
                    img_ts, img, pose, vel = aligned_gt[global_idx]
                    ok, img_encoded = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, 85])
                    if not ok:
                        raise RuntimeError(
                            f"JPEG encoding failed for frame {global_idx} (t={img_ts:.6f})")
                    sink.write({
                        '__key__': f'frame_{global_idx:08d}',
                        'jpg': img_encoded.tobytes(),
                        'ts': np.array([img_ts], dtype=np.float64).tobytes(),
                        'pose': pose.astype(np.float32).tobytes(),
                        'vel': vel.astype(np.float32).tobytes(),
                    })
            print(f" Saved {tar_path} ({end_idx - start_idx} samples)")

    @staticmethod
    def _rate_hz(timestamps) -> float:
        """Return the average sample rate: (count - 1) intervals over the first-to-last time span.

        Returns 0 for fewer than two samples or a non-positive span. Assumes
        ``timestamps`` is in chronological order.
        """
        if len(timestamps) < 2:
            return 0
        duration = timestamps[-1] - timestamps[0]
        return (len(timestamps) - 1) / duration if duration > 0 else 0

    def save_metadata(self):
        """Write a JSON summary of the conversion to ``<output_dir>/metadata.json``."""
        metadata = {
            'dataset_name': self.dataset_name,
            'source_bag': str(self.bag_path),
            'num_images': len(self.images),
            'num_imu_messages': len(self.imu_timestamps),
            'num_ground_truth': len(self.gt_timestamps),
            'image_size': [self.image_width, self.image_height],
            'imu_frequency_hz': self._rate_hz(self.imu_timestamps),
            'camera_frequency_hz': self._rate_hz([ts for ts, _ in self.images]),
            'gt_frequency_hz': self._rate_hz(self.gt_timestamps),
            'coordinate_system': 'horizontal (z aligned with gravity, assumed from GT)',
            'velocity_dimensions': 6,  # (vx, vy, vz, wx, wy, wz)
        }
        self.output_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = self.output_dir / 'metadata.json'
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        print(f"Saved metadata: {metadata_path}")

    def convert(self):
        """Run the full conversion pipeline.

        Calls ``sys.exit(1)`` if no image could be aligned with ground truth.
        """
        print(f"\n{'='*60}")
        print(f"Converting: {self.bag_path.name}")
        print(f"Output: {self.output_dir}")
        print(f"{'='*60}\n")
        # Step 1: extract all data from the bag.
        self.extract_all_data()
        # Drop images/IMU recorded while no ground truth is available.
        self.crop_to_gt_time_range()
        # Step 2: save the IMU sequence.
        self.save_imu_sequence()
        # Step 3: align ground truth to images.
        aligned_gt = self.align_ground_truth_to_images()
        if len(aligned_gt) == 0:
            print("Error: No aligned ground truth found!")
            sys.exit(1)
        # Step 4: save as WebDataset shards.
        self.save_as_webdataset(aligned_gt)
        # Step 5: save metadata.
        self.save_metadata()
        self.diagnose_timestamps()
        print(f"\n✅ Conversion completed for {self.bag_path.name}")

    def diagnose_timestamps(self):
        """Print image and GT timestamp ranges and their initial offset, for debugging."""
        img_timestamps = [t for t, _ in self.images]
        gt_timestamps = self.gt_timestamps
        # min()/max() raise on empty sequences, so check before computing ranges.
        if not img_timestamps or not gt_timestamps:
            print("Timestamp diagnosis skipped: no image or GT timestamps")
            return
        print(f"Image timestamps: {min(img_timestamps):.3f} - {max(img_timestamps):.3f}")
        print(f"GT timestamps: {min(gt_timestamps):.3f} - {max(gt_timestamps):.3f}")
        print(f"Image duration: {max(img_timestamps) - min(img_timestamps):.3f}s")
        print(f"GT duration: {max(gt_timestamps) - min(gt_timestamps):.3f}s")
        # A large value here hints at a constant clock offset between streams.
        offset = gt_timestamps[0] - img_timestamps[0]
        print(f"Initial offset (first GT - first image): {offset:.3f}s")
def main():
    """Command-line entry point: parse arguments, validate the bag path, run the converter.

    Exits with status 1 if the ``--bag`` path does not exist.
    """
    parser = argparse.ArgumentParser(description='Convert ROS bag to WebDataset format')
    # (flag, argparse keyword options) for every CLI option, declared as one table.
    option_specs = [
        ('--bag', dict(type=str, required=True, help='Path to ROS bag file')),
        ('--output', dict(type=str, default='./dataset', help='Output directory')),
        ('--name', dict(type=str, default=None,
                        help='Dataset name (default: bag filename without extension)')),
        ('--shard_size', dict(type=int, default=2000, help='Number of samples per shard')),
        ('--width', dict(type=int, default=320, help='Image width (resize)')),
        ('--height', dict(type=int, default=240, help='Image height (resize)')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    cli = parser.parse_args()

    bag_file = cli.bag
    if not os.path.exists(bag_file):
        print(f"Error: Bag file not found: {bag_file}")
        sys.exit(1)

    # Fall back to the bag's file name (without extension) when no name is given.
    dataset_name = Path(bag_file).stem if cli.name is None else cli.name

    BagToWebDataset(
        bag_path=bag_file,
        output_dir=cli.output,
        dataset_name=dataset_name,
        shard_size=cli.shard_size,
        image_width=cli.width,
        image_height=cli.height,
    ).convert()


if __name__ == '__main__':
    main()