#!/usr/bin/env python3
"""
ROS bag to WebDataset converter for the DAVIS dataset.

Extracts grayscale images, the IMU sequence, and ground-truth poses and
velocities from a bag, then writes into ``<output>/<name>/``:

* ``imu_sequence.npz`` - IMU stream cropped to the ground-truth time span,
* ``shard_XXXX.tar``   - WebDataset shards, one sample per image with the
  nearest ground-truth pose/velocity attached,
* ``metadata.json``    - dataset summary.

Usage:
    python convert_bag_to_webdataset.py --bag BAG [--output DIR] [--name NAME]
"""

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import rosbag
from cv_bridge import CvBridge
import cv2
import webdataset as wds
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R

# (image timestamp, image, pose (x, y, z, qx, qy, qz, qw), velocity (vx, vy, vz, wx, wy, wz))
AlignedSample = Tuple[float, np.ndarray, np.ndarray, np.ndarray]


class BagToWebDataset:
    """Convert a single DAVIS ROS bag into WebDataset shards plus side files.

    Every stream is timestamped with the bag record time ``t`` yielded by
    ``rosbag.Bag.read_messages`` (``t.to_sec()``), not ``msg.header.stamp``,
    so images, IMU and ground truth are compared on the same clock.
    """

    IMAGE_TOPIC = '/dvs/image_raw'
    IMU_TOPIC = '/dvs/imu'
    GT_TOPIC = '/groundtruth/odometry'
    TOPICS = (IMAGE_TOPIC, IMU_TOPIC, GT_TOPIC)

    def __init__(self, bag_path: str, output_dir: str, dataset_name: str,
                 shard_size: int = 2000, image_width: int = 320,
                 image_height: int = 240):
        """
        Args:
            bag_path: Path to the input bag file.
            output_dir: Root output directory; files are written to
                ``output_dir / dataset_name``.
            dataset_name: Sub-directory name, also stored in the metadata.
            shard_size: Maximum number of samples per tar shard.
            image_width: Target width for resizing (0/None disables resizing).
            image_height: Target height for resizing (0/None disables resizing).

        Raises:
            ValueError: If ``shard_size`` is not positive.
        """
        if shard_size <= 0:
            raise ValueError(f"shard_size must be positive, got {shard_size}")

        self.bag_path = Path(bag_path)
        self.output_dir = Path(output_dir) / dataset_name
        self.dataset_name = dataset_name
        self.shard_size = shard_size
        self.image_width = image_width
        self.image_height = image_height
        self.bridge = CvBridge()

        # Data containers, filled in message order by the _process_* handlers.
        self.images: List[Tuple[float, np.ndarray]] = []  # (timestamp, mono8 image)
        self.imu_timestamps: List[float] = []
        self.imu_acc: List[np.ndarray] = []  # (ax, ay, az)
        self.imu_gyro: List[np.ndarray] = []  # (gx, gy, gz)
        self.gt_timestamps: List[float] = []
        self.gt_poses: List[np.ndarray] = []  # (x, y, z, qx, qy, qz, qw)
        # (vx, vy, vz, wx, wy, wz); None until _ensure_velocities fills it in.
        self.gt_velocities: List[Optional[np.ndarray]] = []

    def extract_all_data(self):
        """Read all image, IMU and odometry messages, then derive GT velocities."""
        print(f"Opening bag: {self.bag_path}")
        topics = list(self.TOPICS)
        handlers = {
            self.IMAGE_TOPIC: self._process_image,
            self.IMU_TOPIC: self._process_imu,
            self.GT_TOPIC: self._process_odometry,
        }

        # The context manager closes the bag even if a handler raises.
        with rosbag.Bag(str(self.bag_path), 'r') as bag:
            # Count messages for the progress bar.
            topic_counts = {topic: bag.get_message_count(topic) for topic in topics}
            print(f"Topics: {topic_counts}")

            with tqdm(total=sum(topic_counts.values()), desc="Extracting messages") as pbar:
                for topic, msg, t in bag.read_messages(topics=topics):
                    handler = handlers.get(topic)
                    if handler is not None:
                        handler(msg, t)
                    pbar.update(1)

        # Post-processing: velocities are always derived from pose differences.
        self._ensure_velocities()

        print("\nExtraction completed:")
        print(f" Images: {len(self.images)}")
        print(f" IMU messages: {len(self.imu_timestamps)}")
        print(f" Ground truth poses: {len(self.gt_timestamps)}")
        print(f" Ground truth velocities: {len(self.gt_velocities)}")

    def crop_to_gt_time_range(self):
        """Drop image and IMU samples outside [first GT timestamp, last GT timestamp].

        Ground truth itself is left untouched: it defines the range.
        If there is no GT data, nothing is cropped.
        """
        if not self.gt_timestamps:
            print("Warning: No GT data found, skipping crop")
            return

        gt_start = min(self.gt_timestamps)
        gt_end = max(self.gt_timestamps)
        print(f"\nCropping to GT time range: {gt_start:.3f} - {gt_end:.3f} ({gt_end - gt_start:.1f}s)")

        # Crop images.
        original_img_count = len(self.images)
        self.images = [(ts, img) for ts, img in self.images if gt_start <= ts <= gt_end]
        print(f" Images: {original_img_count} -> {len(self.images)}")

        # Crop IMU. Always reassign, so an empty result really empties the
        # streams instead of silently keeping the uncropped data.
        original_imu_count = len(self.imu_timestamps)
        imu_filtered = [(ts, acc, gyro)
                        for ts, acc, gyro in zip(self.imu_timestamps, self.imu_acc, self.imu_gyro)
                        if gt_start <= ts <= gt_end]
        self.imu_timestamps = [ts for ts, _, _ in imu_filtered]
        self.imu_acc = [acc for _, acc, _ in imu_filtered]
        self.imu_gyro = [gyro for _, _, gyro in imu_filtered]
        print(f" IMU: {original_imu_count} -> {len(self.imu_timestamps)}")

        # GT data lies within its own range by definition.
        print(f" GT: {len(self.gt_timestamps)} (unchanged)")

    def _process_image(self, msg, t):
        """Convert one image message to mono8, optionally resize it, and store it.

        Conversion errors are reported and the frame is skipped (best effort),
        so one bad message does not abort the whole conversion.
        """
        try:
            cv_img = self.bridge.imgmsg_to_cv2(msg, desired_encoding='mono8')
            if self.image_width and self.image_height:
                cv_img = cv2.resize(cv_img, (self.image_width, self.image_height),
                                    interpolation=cv2.INTER_LINEAR)
        except Exception as e:
            print(f"Error processing image: {e}")
            return
        # Bag record time rather than msg.header.stamp (see class docstring).
        self.images.append((t.to_sec(), cv_img))

    def _process_imu(self, msg, t):
        """Store one IMU sample (linear acceleration and angular velocity)."""
        self.imu_timestamps.append(t.to_sec())  # bag record time
        # Linear acceleration (m/s^2 per sensor_msgs/Imu convention)
        self.imu_acc.append(np.array([msg.linear_acceleration.x,
                                      msg.linear_acceleration.y,
                                      msg.linear_acceleration.z], dtype=np.float32))
        # Angular velocity (rad/s per sensor_msgs/Imu convention)
        self.imu_gyro.append(np.array([msg.angular_velocity.x,
                                       msg.angular_velocity.y,
                                       msg.angular_velocity.z], dtype=np.float32))

    def _process_odometry(self, msg, t):
        """Store one ground-truth pose as (x, y, z, qx, qy, qz, qw)."""
        position = msg.pose.pose.position
        orientation = msg.pose.pose.orientation
        pose = np.array([position.x, position.y, position.z,
                         orientation.x, orientation.y, orientation.z, orientation.w],
                        dtype=np.float32)
        self.gt_timestamps.append(t.to_sec())  # bag record time
        self.gt_poses.append(pose)
        # The twist in this dataset is reported as zero, so velocity is always
        # computed from pose differences in _ensure_velocities.
        self.gt_velocities.append(None)

    def _pose_rate(self, i0: int, i1: int) -> np.ndarray:
        """Return the finite-difference velocity from GT sample i0 to i1.

        The result is (vx, vy, vz, wx, wy, wz) as float32, or zeros if the
        time step is not positive. Angular velocity is the rotation vector of
        ``R1 * R0.inv()`` divided by dt, i.e. expressed in the frame the poses
        are given in (presumably world) - TODO confirm.
        """
        dt = self.gt_timestamps[i1] - self.gt_timestamps[i0]
        if dt <= 0:
            return np.zeros(6, dtype=np.float32)
        pose0, pose1 = self.gt_poses[i0], self.gt_poses[i1]
        v_lin = (pose1[:3] - pose0[:3]) / dt
        dq = R.from_quat(pose1[3:7]) * R.from_quat(pose0[3:7]).inv()
        v_ang = dq.as_rotvec() / dt
        return np.concatenate([v_lin, v_ang]).astype(np.float32)

    def _ensure_velocities(self):
        """Fill every missing (None) GT velocity from pose differences.

        The first sample uses a forward difference (0 -> 1). Every later sample
        i uses a backward difference (i-1 -> i). A lone sample gets zeros.
        """
        print("Computing velocities from pose differences...")
        n = len(self.gt_timestamps)
        if n == 0:
            return
        if n == 1:
            computed = [np.zeros(6, dtype=np.float32)]
        else:
            computed = [self._pose_rate(0, 1)] + [self._pose_rate(i - 1, i) for i in range(1, n)]

        self.gt_velocities = [vel if vel is not None else est
                              for vel, est in zip(self.gt_velocities, computed)]

    def save_imu_sequence(self):
        """Save the IMU stream to ``imu_sequence.npz`` and return its path."""
        imu_data = {
            'timestamps': np.array(self.imu_timestamps, dtype=np.float64),
            # reshape keeps a (N, 3) layout even when N == 0.
            'accelerations': np.array(self.imu_acc, dtype=np.float32).reshape(-1, 3),
            'angular_velocities': np.array(self.imu_gyro, dtype=np.float32).reshape(-1, 3),
        }
        imu_path = self.output_dir / 'imu_sequence.npz'
        imu_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(imu_path, **imu_data)
        print(f"Saved IMU sequence: {imu_path}")
        return imu_path

    def align_ground_truth_to_images(self, max_time_diff: float = 0.1) -> List[AlignedSample]:
        """Attach the nearest-in-time GT pose and velocity to every image.

        Args:
            max_time_diff: Images whose nearest GT sample is ``max_time_diff``
                seconds or more away are dropped.

        Returns:
            A list of ``(image_ts, image, pose, velocity)`` tuples in image
            order; empty if there is no GT data.
        """
        if not self.gt_timestamps:
            return []

        gt_timestamps = np.array(self.gt_timestamps)
        gt_poses = np.array(self.gt_poses)
        gt_vels = np.array(self.gt_velocities)

        aligned_gt = []
        for img_ts, img in tqdm(self.images, desc="Aligning ground truth to images"):
            idx = np.argmin(np.abs(gt_timestamps - img_ts))
            if abs(gt_timestamps[idx] - img_ts) < max_time_diff:
                aligned_gt.append((img_ts, img, gt_poses[idx], gt_vels[idx]))
        return aligned_gt

    def save_as_webdataset(self, aligned_gt: List[AlignedSample]):
        """Write aligned samples into ``shard_XXXX.tar`` files of ``shard_size`` samples.

        Each sample holds ``jpg`` (JPEG quality 85), ``ts`` (float64 bytes),
        ``pose`` (7 x float32 bytes) and ``vel`` (6 x float32 bytes). Keys are
        ``frame_XXXXXXXX`` with a global index, so they are unique across shards.

        Raises:
            RuntimeError: If an image cannot be JPEG-encoded.
        """
        num_shards = (len(aligned_gt) + self.shard_size - 1) // self.shard_size
        print(f"Saving {len(aligned_gt)} samples into {num_shards} shards...")
        self.output_dir.mkdir(parents=True, exist_ok=True)

        for shard_idx in range(num_shards):
            start_idx = shard_idx * self.shard_size
            end_idx = min(start_idx + self.shard_size, len(aligned_gt))
            tar_path = self.output_dir / f'shard_{shard_idx:04d}.tar'

            with wds.TarWriter(str(tar_path)) as sink:
                # Only this shard's slice; the global index keeps keys unique.
                for global_idx, (img_ts, img, pose, vel) in enumerate(
                        aligned_gt[start_idx:end_idx], start=start_idx):
                    ok, img_encoded = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, 85])
                    if not ok:
                        raise RuntimeError(f"JPEG encoding failed for sample {global_idx}")

                    sink.write({
                        '__key__': f'frame_{global_idx:08d}',
                        'jpg': img_encoded.tobytes(),
                        'ts': np.array([img_ts], dtype=np.float64).tobytes(),
                        'pose': pose.astype(np.float32).tobytes(),
                        'vel': vel.astype(np.float32).tobytes()
                    })

            print(f" Saved {tar_path} ({end_idx - start_idx} samples)")

    @staticmethod
    def _frequency_hz(timestamps) -> float:
        """Return ``len(timestamps) / (last - first)``, or 0 when undefined."""
        if len(timestamps) < 2:
            return 0
        duration = timestamps[-1] - timestamps[0]
        return len(timestamps) / duration if duration > 0 else 0

    def save_metadata(self):
        """Write a ``metadata.json`` summary of the extracted streams."""
        metadata = {
            'dataset_name': self.dataset_name,
            'source_bag': str(self.bag_path),
            'num_images': len(self.images),
            'num_imu_messages': len(self.imu_timestamps),
            'num_ground_truth': len(self.gt_timestamps),
            'image_size': [self.image_width, self.image_height],
            'imu_frequency_hz': self._frequency_hz(self.imu_timestamps),
            'camera_frequency_hz': self._frequency_hz([ts for ts, _ in self.images]),
            'gt_frequency_hz': self._frequency_hz(self.gt_timestamps),
            'coordinate_system': 'horizontal (z aligned with gravity, assumed from GT)',
            'velocity_dimensions': 6,  # (vx, vy, vz, wx, wy, wz)
        }

        self.output_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = self.output_dir / 'metadata.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"Saved metadata: {metadata_path}")

    def convert(self):
        """Run the full pipeline; exit with status 1 if no image gets aligned GT."""
        print(f"\n{'='*60}")
        print(f"Converting: {self.bag_path.name}")
        print(f"Output: {self.output_dir}")
        print(f"{'='*60}\n")

        # Step 1: Extract all data from the bag.
        self.extract_all_data()

        # Drop the parts of the recording that have no GT coverage.
        self.crop_to_gt_time_range()

        # Step 2: Save the IMU sequence.
        self.save_imu_sequence()

        # Step 3: Align ground truth to images.
        aligned_gt = self.align_ground_truth_to_images()
        if len(aligned_gt) == 0:
            print("Error: No aligned ground truth found!")
            sys.exit(1)

        # Step 4: Save as WebDataset.
        self.save_as_webdataset(aligned_gt)

        # Step 5: Save metadata.
        self.save_metadata()

        self.diagnose_timestamps()

        print(f"\n✅ Conversion completed for {self.bag_path.name}")

    def diagnose_timestamps(self):
        """Print image/GT timestamp ranges and their initial offset for debugging."""
        img_timestamps = [ts for ts, _ in self.images]
        gt_timestamps = self.gt_timestamps
        # min()/max() below would raise on an empty stream.
        if not img_timestamps or not gt_timestamps:
            print("Timestamp diagnosis skipped: no image or GT timestamps")
            return

        print(f"Image timestamps: {min(img_timestamps):.3f} - {max(img_timestamps):.3f}")
        print(f"GT timestamps: {min(gt_timestamps):.3f} - {max(gt_timestamps):.3f}")
        print(f"Image duration: {max(img_timestamps) - min(img_timestamps):.3f}s")
        print(f"GT duration: {max(gt_timestamps) - min(gt_timestamps):.3f}s")

        # A constant offset here would hint at a clock mismatch between streams.
        offset = gt_timestamps[0] - img_timestamps[0]
        print(f"Initial offset (first GT - first image): {offset:.3f}s")


def main():
    """Parse command-line arguments and convert one bag."""
    parser = argparse.ArgumentParser(description='Convert ROS bag to WebDataset format')
    parser.add_argument('--bag', type=str, required=True, help='Path to ROS bag file')
    parser.add_argument('--output', type=str, default='./dataset', help='Output directory')
    parser.add_argument('--name', type=str, default=None,
                        help='Dataset name (default: bag filename without extension)')
    parser.add_argument('--shard_size', type=int, default=2000, help='Number of samples per shard')
    parser.add_argument('--width', type=int, default=320, help='Image width (resize)')
    parser.add_argument('--height', type=int, default=240, help='Image height (resize)')
    args = parser.parse_args()

    # Validate inputs.
    if not os.path.exists(args.bag):
        print(f"Error: Bag file not found: {args.bag}")
        sys.exit(1)

    # Default the dataset name to the bag file stem.
    if args.name is None:
        args.name = Path(args.bag).stem

    converter = BagToWebDataset(
        bag_path=args.bag,
        output_dir=args.output,
        dataset_name=args.name,
        shard_size=args.shard_size,
        image_width=args.width,
        image_height=args.height
    )
    converter.convert()


if __name__ == '__main__':
    main()