396 lines
16 KiB
Python
396 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
ROS bag to WebDataset converter for DAVIS dataset
|
|
Extracts: grayscale images, IMU sequence, ground truth poses and velocities
|
|
|
|
Usage:
|
|
python convert_bag_to_webdataset.py --bag <path_to.bag> --output <output_dir> --name <dataset_name>
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple, Optional
|
|
import numpy as np
|
|
import rosbag
|
|
from cv_bridge import CvBridge
|
|
import cv2
|
|
import webdataset as wds
|
|
from tqdm import tqdm
|
|
from scipy.spatial.transform import Rotation as R
|
|
|
|
|
|
class BagToWebDataset:
    """Convert one ROS bag into WebDataset shards, an IMU archive and metadata.

    The converter reads three topics from the bag (grayscale images, IMU and
    ground-truth odometry), keeps everything in memory, crops images/IMU to
    the ground-truth time window, pairs every image with its nearest
    ground-truth sample and writes the result below
    ``<output_dir>/<dataset_name>``.
    """

    # Topics read from the bag.
    IMAGE_TOPIC = '/dvs/image_raw'
    IMU_TOPIC = '/dvs/imu'
    ODOMETRY_TOPIC = '/groundtruth/odometry'

    def __init__(self, bag_path: str, output_dir: str, dataset_name: str,
                 shard_size: int = 2000, image_width: int = 320, image_height: int = 240):
        """Store the conversion settings and create empty data containers.

        Args:
            bag_path: Path of the ROS bag to read.
            output_dir: Parent directory; output goes to ``output_dir/dataset_name``.
            dataset_name: Name of the dataset (also the output sub-directory).
            shard_size: Maximum number of samples per tar shard (must be > 0).
            image_width: Target image width; resizing is skipped when width
                or height is falsy.
            image_height: Target image height.

        Raises:
            ValueError: If ``shard_size`` is not positive.
        """
        # Checked up front: a non-positive shard size would otherwise only
        # surface as a ZeroDivisionError when the shards are written.
        if shard_size <= 0:
            raise ValueError(f"shard_size must be positive, got {shard_size}")

        self.bag_path = Path(bag_path)
        self.output_dir = Path(output_dir) / dataset_name
        self.dataset_name = dataset_name
        self.shard_size = shard_size
        self.image_width = image_width
        self.image_height = image_height

        self.bridge = CvBridge()

        # Data containers
        self.images: List[Tuple[float, np.ndarray]] = []  # (timestamp, image)
        self.imu_timestamps: List[float] = []
        self.imu_acc: List[np.ndarray] = []  # (ax, ay, az)
        self.imu_gyro: List[np.ndarray] = []  # (gx, gy, gz)
        self.gt_timestamps: List[float] = []
        self.gt_poses: List[np.ndarray] = []  # (x, y, z, qx, qy, qz, qw)
        # (vx, vy, vz, wx, wy, wz); None until _ensure_velocities() fills it in.
        self.gt_velocities: List[Optional[np.ndarray]] = []

    def extract_all_data(self):
        """Read images, IMU and odometry messages from the bag into memory.

        After reading, ground-truth velocities are derived from pose
        differences and a short summary is printed.
        """
        handlers = {
            self.IMAGE_TOPIC: self._process_image,
            self.IMU_TOPIC: self._process_imu,
            self.ODOMETRY_TOPIC: self._process_odometry,
        }
        topics = list(handlers)

        print(f"Opening bag: {self.bag_path}")
        bag = rosbag.Bag(str(self.bag_path), 'r')
        try:
            # Count messages for progress bar
            topic_counts = {topic: bag.get_message_count(topic) for topic in topics}
            total_msgs = sum(topic_counts.values())

            print(f"Topics: {topic_counts}")

            with tqdm(total=total_msgs, desc="Extracting messages") as pbar:
                for topic, msg, t in bag.read_messages(topics=topics):
                    handler = handlers.get(topic)
                    if handler is not None:
                        # Pass the time yielded by read_messages, not the header stamp.
                        handler(msg, t)
                    pbar.update(1)
        finally:
            # Close the bag even when a handler raises.
            bag.close()

        # Post-processing: compute velocities from poses
        self._ensure_velocities()

        # Print statistics
        print(f"\nExtraction completed:")
        print(f" Images: {len(self.images)}")
        print(f" IMU messages: {len(self.imu_timestamps)}")
        print(f" Ground truth poses: {len(self.gt_timestamps)}")
        print(f" Ground truth velocities: {len(self.gt_velocities)}")

    def crop_to_gt_time_range(self):
        """Drop images and IMU samples outside the ground-truth time range.

        The range is ``[min(gt_timestamps), max(gt_timestamps)]`` (inclusive).
        Nothing is changed when there is no ground truth at all.
        """
        if len(self.gt_timestamps) == 0:
            print("Warning: No GT data found, skipping crop")
            return

        gt_start = min(self.gt_timestamps)
        gt_end = max(self.gt_timestamps)

        print(f"\nCropping to GT time range: {gt_start:.3f} - {gt_end:.3f} ({gt_end - gt_start:.1f}s)")

        # Crop images
        original_img_count = len(self.images)
        self.images = [(ts, img) for ts, img in self.images if gt_start <= ts <= gt_end]
        print(f" Images: {original_img_count} -> {len(self.images)}")

        # Crop IMU. The three lists are always replaced, so an IMU stream that
        # lies entirely outside the GT window ends up empty instead of being
        # kept uncropped.
        original_imu_count = len(self.imu_timestamps)
        imu_filtered = [(ts, acc, gyro) for ts, acc, gyro
                        in zip(self.imu_timestamps, self.imu_acc, self.imu_gyro)
                        if gt_start <= ts <= gt_end]
        self.imu_timestamps = [item[0] for item in imu_filtered]
        self.imu_acc = [item[1] for item in imu_filtered]
        self.imu_gyro = [item[2] for item in imu_filtered]
        print(f" IMU: {original_imu_count} -> {len(self.imu_timestamps)}")

        # GT data is inside the range by construction, so it is not cropped
        print(f" GT: {len(self.gt_timestamps)} (unchanged)")

    def _process_image(self, msg, t):
        """Convert an image message to mono8, optionally resize it and store it.

        Args:
            msg: Image message handed to ``CvBridge.imgmsg_to_cv2``.
            t: Time yielded by ``bag.read_messages``; ``t.to_sec()`` becomes the
                sample timestamp (``msg.header.stamp`` is deliberately not used).

        Any exception is reported on stdout and the frame is skipped.
        """
        try:
            # Convert ROS image to OpenCV format
            cv_img = self.bridge.imgmsg_to_cv2(msg, desired_encoding='mono8')

            # Resize if needed
            if self.image_width and self.image_height:
                cv_img = cv2.resize(cv_img, (self.image_width, self.image_height),
                                    interpolation=cv2.INTER_LINEAR)

            timestamp = t.to_sec()
            self.images.append((timestamp, cv_img))

        except Exception as e:
            # Best effort: one bad frame must not abort the whole conversion.
            print(f"Error processing image: {e}")

    def _process_imu(self, msg, t):
        """Store one IMU sample (acceleration and angular velocity).

        Args:
            msg: Message exposing ``linear_acceleration`` and ``angular_velocity``.
            t: Time yielded by ``bag.read_messages``; ``t.to_sec()`` is stored.
        """
        timestamp = t.to_sec()

        # Linear acceleration (m/s^2)
        acc = np.array([msg.linear_acceleration.x,
                        msg.linear_acceleration.y,
                        msg.linear_acceleration.z], dtype=np.float32)

        # Angular velocity (rad/s)
        gyro = np.array([msg.angular_velocity.x,
                         msg.angular_velocity.y,
                         msg.angular_velocity.z], dtype=np.float32)

        self.imu_timestamps.append(timestamp)
        self.imu_acc.append(acc)
        self.imu_gyro.append(gyro)

    def _process_odometry(self, msg, t):
        """Store one ground-truth pose as ``(x, y, z, qx, qy, qz, qw)``.

        Args:
            msg: Odometry message; only ``msg.pose.pose`` is read.
            t: Time yielded by ``bag.read_messages``; ``t.to_sec()`` is stored.
        """
        timestamp = t.to_sec()

        # Position (x, y, z)
        pos = np.array([msg.pose.pose.position.x,
                        msg.pose.pose.position.y,
                        msg.pose.pose.position.z], dtype=np.float32)

        # Orientation (qx, qy, qz, qw)
        quat = np.array([msg.pose.pose.orientation.x,
                         msg.pose.pose.orientation.y,
                         msg.pose.pose.orientation.z,
                         msg.pose.pose.orientation.w], dtype=np.float32)

        self.gt_timestamps.append(timestamp)
        self.gt_poses.append(np.concatenate([pos, quat]))

        # Velocity is always computed from pose differences in post-processing,
        # so only a placeholder is stored here.
        self.gt_velocities.append(None)

    @staticmethod
    def _velocity_between(pose_prev, pose_next, dt):
        """Return ``(vx, vy, vz, wx, wy, wz)`` between two poses ``dt`` seconds apart.

        Returns zeros when ``dt`` is not positive. The angular part is the
        rotation vector of ``q_next * q_prev^-1`` divided by ``dt``; it is
        zero when either quaternion has no usable norm, because
        ``Rotation.from_quat`` rejects such input.
        """
        if dt <= 0:
            return np.zeros(6, dtype=np.float32)

        v_lin = (pose_next[:3] - pose_prev[:3]) / dt

        q_prev = pose_prev[3:7]
        q_next = pose_next[3:7]
        if np.linalg.norm(q_prev) > 0 and np.linalg.norm(q_next) > 0:
            dq = R.from_quat(q_next) * R.from_quat(q_prev).inv()
            v_ang = dq.as_rotvec() / dt
        else:
            v_ang = np.zeros(3, dtype=np.float32)

        return np.concatenate([v_lin, v_ang])

    def _ensure_velocities(self):
        """Fill every missing ground-truth velocity from pose differences.

        The twist data in this dataset is zero, so velocities are always
        derived from timestamps and poses: a forward difference for the first
        sample and a backward difference for every other sample.
        """
        print("Computing velocities from pose differences...")

        count = len(self.gt_timestamps)
        computed_velocities = []
        for i in range(count):
            if count < 2:
                # A single pose has no neighbour to difference against.
                computed_velocities.append(np.zeros(6, dtype=np.float32))
                continue
            prev_idx, next_idx = (0, 1) if i == 0 else (i - 1, i)
            dt = self.gt_timestamps[next_idx] - self.gt_timestamps[prev_idx]
            computed_velocities.append(
                self._velocity_between(self.gt_poses[prev_idx], self.gt_poses[next_idx], dt))

        # Replace missing velocities
        for i, vel in enumerate(self.gt_velocities):
            if vel is None:
                self.gt_velocities[i] = computed_velocities[i]

    def save_imu_sequence(self):
        """Write the IMU sequence to ``<output_dir>/imu_sequence.npz``.

        Returns:
            Path of the written archive. It holds ``timestamps`` (float64),
            ``accelerations`` and ``angular_velocities`` (float32, reshaped to
            ``(N, 3)`` so that an empty sequence keeps a consistent shape).
        """
        imu_data = {
            'timestamps': np.array(self.imu_timestamps, dtype=np.float64),
            'accelerations': np.array(self.imu_acc, dtype=np.float32).reshape(-1, 3),
            'angular_velocities': np.array(self.imu_gyro, dtype=np.float32).reshape(-1, 3),
        }

        imu_path = self.output_dir / 'imu_sequence.npz'
        imu_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(imu_path, **imu_data)

        print(f"Saved IMU sequence: {imu_path}")
        return imu_path

    def align_ground_truth_to_images(
            self, max_time_diff: float = 0.1
    ) -> List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]:
        """Pair each image with the ground-truth sample nearest in time.

        Args:
            max_time_diff: Images whose nearest ground-truth sample is this many
                seconds away or more are dropped.

        Returns:
            A list of ``(image_timestamp, image, pose, velocity)`` tuples; empty
            when there are no images or no ground truth.
        """
        aligned_gt = []
        # argmin on an empty array raises, so bail out before the loop.
        if not self.images or not self.gt_timestamps:
            return aligned_gt

        gt_timestamps = np.array(self.gt_timestamps)
        gt_poses = np.array(self.gt_poses)
        gt_vels = np.array(self.gt_velocities)

        for img_ts, img in tqdm(self.images, desc="Aligning ground truth to images"):
            idx = np.argmin(np.abs(gt_timestamps - img_ts))
            time_diff = abs(gt_timestamps[idx] - img_ts)

            if time_diff < max_time_diff:
                # The image is carried along so it can be written with its labels.
                aligned_gt.append((img_ts, img, gt_poses[idx], gt_vels[idx]))

        return aligned_gt

    def save_as_webdataset(self, aligned_gt: List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]):
        """Write the aligned samples as ``shard_XXXX.tar`` files.

        Each shard receives at most ``shard_size`` consecutive samples. Every
        sample holds a JPEG image (``jpg``), the image timestamp as float64
        bytes (``ts``) and pose / velocity as float32 bytes (``pose``, ``vel``).
        Sample keys are ``frame_<index>`` where the index is the sample's
        position in ``aligned_gt``, so keys are unique across shards.

        Raises:
            RuntimeError: If OpenCV reports that JPEG encoding failed.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)

        num_shards = (len(aligned_gt) + self.shard_size - 1) // self.shard_size

        print(f"Saving {len(aligned_gt)} samples into {num_shards} shards...")

        for shard_idx in range(num_shards):
            start_idx = shard_idx * self.shard_size
            end_idx = min((shard_idx + 1) * self.shard_size, len(aligned_gt))

            tar_path = self.output_dir / f'shard_{shard_idx:04d}.tar'

            with wds.TarWriter(str(tar_path)) as sink:
                # Only this shard's slice is written; iterating the whole list
                # would copy the full dataset into every shard.
                for sample_idx, (img_ts, img, pose, vel) in enumerate(
                        aligned_gt[start_idx:end_idx], start=start_idx):

                    # Encode image as JPEG
                    ok, img_encoded = cv2.imencode('.jpg', img,
                                                   [cv2.IMWRITE_JPEG_QUALITY, 85])
                    if not ok:
                        raise RuntimeError(f"JPEG encoding failed for sample {sample_idx}")

                    # Write to tar
                    sink.write({
                        '__key__': f'frame_{sample_idx:08d}',
                        'jpg': img_encoded.tobytes(),
                        'ts': np.array([img_ts], dtype=np.float64).tobytes(),
                        'pose': pose.astype(np.float32).tobytes(),
                        'vel': vel.astype(np.float32).tobytes()
                    })

            print(f" Saved {tar_path} ({end_idx - start_idx} samples)")

    @staticmethod
    def _rate_hz(timestamps) -> float:
        """Return the mean sample rate of ``timestamps`` (first to last), or 0.

        N samples span N-1 intervals, so the rate is ``(N - 1) / duration``.
        Fewer than two samples or a non-positive duration yields 0 instead of
        dividing by zero.
        """
        if len(timestamps) < 2:
            return 0.0
        duration = timestamps[-1] - timestamps[0]
        if duration <= 0:
            return 0.0
        return (len(timestamps) - 1) / duration

    def save_metadata(self):
        """Write dataset statistics to ``<output_dir>/metadata.json``."""
        metadata = {
            'dataset_name': self.dataset_name,
            'source_bag': str(self.bag_path),
            'num_images': len(self.images),
            'num_imu_messages': len(self.imu_timestamps),
            'num_ground_truth': len(self.gt_timestamps),
            'image_size': [self.image_width, self.image_height],
            'imu_frequency_hz': self._rate_hz(self.imu_timestamps),
            'camera_frequency_hz': self._rate_hz([ts for ts, _ in self.images]),
            'gt_frequency_hz': self._rate_hz(self.gt_timestamps),
            'coordinate_system': 'horizontal (z aligned with gravity, assumed from GT)',
            'velocity_dimensions': 6,  # (vx, vy, vz, wx, wy, wz)
        }

        metadata_path = self.output_dir / 'metadata.json'
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        print(f"Saved metadata: {metadata_path}")

    def convert(self):
        """Run the full pipeline: extract, crop, save IMU, align, shard, metadata.

        Exits the process with status 1 when no image could be aligned with
        ground truth.
        """
        print(f"\n{'='*60}")
        print(f"Converting: {self.bag_path.name}")
        print(f"Output: {self.output_dir}")
        print(f"{'='*60}\n")

        # Step 1: Extract all data from bag
        self.extract_all_data()

        # Drop the time spans that have no ground truth
        self.crop_to_gt_time_range()

        # Step 2: Save IMU sequence
        self.save_imu_sequence()

        # Step 3: Align ground truth to images
        aligned_gt = self.align_ground_truth_to_images()

        if len(aligned_gt) == 0:
            print("Error: No aligned ground truth found!")
            sys.exit(1)

        # Step 4: Save as WebDataset
        self.save_as_webdataset(aligned_gt)

        # Step 5: Save metadata
        self.save_metadata()

        self.diagnose_timestamps()

        print(f"\n✅ Conversion completed for {self.bag_path.name}")

    def diagnose_timestamps(self):
        """Print image / ground-truth timestamp ranges for debugging.

        Returns without printing ranges when either stream is empty, since
        ``min``/``max`` of an empty list would raise.
        """
        img_timestamps = [t for t, _ in self.images]
        gt_timestamps = self.gt_timestamps

        if not img_timestamps or not gt_timestamps:
            print("Warning: No image or GT timestamps to diagnose")
            return

        print(f"Image timestamps: {min(img_timestamps):.3f} - {max(img_timestamps):.3f}")
        print(f"GT timestamps: {min(gt_timestamps):.3f} - {max(gt_timestamps):.3f}")
        print(f"Image duration: {max(img_timestamps) - min(img_timestamps):.3f}s")
        print(f"GT duration: {max(gt_timestamps) - min(gt_timestamps):.3f}s")

        # Check if there's a constant offset
        offset = gt_timestamps[0] - img_timestamps[0]
        print(f"Initial offset (first GT - first image): {offset:.3f}s")
|
|
|
|
def main():
    """Parse the command line, validate the bag path and run the conversion.

    Exits with status 1 when the bag file does not exist. When ``--name`` is
    omitted, the bag file name without its extension is used as dataset name.
    """
    option_specs = (
        ('--bag', dict(type=str, required=True, help='Path to ROS bag file')),
        ('--output', dict(type=str, default='./dataset', help='Output directory')),
        ('--name', dict(type=str, default=None,
                        help='Dataset name (default: bag filename without extension)')),
        ('--shard_size', dict(type=int, default=2000, help='Number of samples per shard')),
        ('--width', dict(type=int, default=320, help='Image width (resize)')),
        ('--height', dict(type=int, default=240, help='Image height (resize)')),
    )

    cli = argparse.ArgumentParser(description='Convert ROS bag to WebDataset format')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Refuse to start without an input bag.
    if not os.path.exists(opts.bag):
        print(f"Error: Bag file not found: {opts.bag}")
        sys.exit(1)

    # Fall back to the bag's stem when no explicit name was given.
    dataset_name = Path(opts.bag).stem if opts.name is None else opts.name

    BagToWebDataset(
        bag_path=opts.bag,
        output_dir=opts.output,
        dataset_name=dataset_name,
        shard_size=opts.shard_size,
        image_width=opts.width,
        image_height=opts.height,
    ).convert()
|
|
|
|
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()