Compare commits
4 Commits
cb9936542e
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 02d429282e | |||
| 1369edaad7 | |||
| b5abbc239d | |||
| e7e773a48f |
@@ -71,9 +71,14 @@ class EventProcessor:
|
|||||||
frame: np.ndarray, shape (H, W) or (H, W, C), uint8 or float.
|
frame: np.ndarray, shape (H, W) or (H, W, C), uint8 or float.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
When threshold > 0:
|
||||||
events_binary: np.ndarray (H, W), values in {-1, 0, +1}
|
events_binary: np.ndarray (H, W), values in {-1, 0, +1}
|
||||||
events_strength: np.ndarray (H, W), values in [-1, 1]
|
events_strength: np.ndarray (H, W), values in [-1, 1]
|
||||||
event_count: int, number of non-zero events
|
event_count: int, number of non-zero events
|
||||||
|
When threshold == 0 (raw output, no thresholding):
|
||||||
|
change_raw: np.ndarray (H, W), raw log/linear brightness change (float32)
|
||||||
|
change_raw: np.ndarray (H, W), same values as the first element (fills the events_strength slot)
|
||||||
|
event_count: int, number of pixels with non-zero change
|
||||||
"""
|
"""
|
||||||
brightness = self._to_grayscale(frame)
|
brightness = self._to_grayscale(frame)
|
||||||
|
|
||||||
@@ -81,10 +86,18 @@ class EventProcessor:
|
|||||||
if self.prev_brightness is None:
|
if self.prev_brightness is None:
|
||||||
self.prev_brightness = brightness
|
self.prev_brightness = brightness
|
||||||
h, w = brightness.shape
|
h, w = brightness.shape
|
||||||
|
if self.threshold == 0:
|
||||||
|
return np.zeros((h, w), dtype=np.float32), np.zeros((h, w), dtype=np.float32), 0
|
||||||
return np.zeros((h, w), dtype=np.int8), np.zeros((h, w), dtype=np.float32), 0
|
return np.zeros((h, w), dtype=np.int8), np.zeros((h, w), dtype=np.float32), 0
|
||||||
|
|
||||||
change = self._compute_change(brightness)
|
change = self._compute_change(brightness)
|
||||||
|
|
||||||
|
# threshold == 0: raw mode, skip thresholding
|
||||||
|
if self.threshold == 0:
|
||||||
|
self.prev_brightness = brightness
|
||||||
|
change_f32 = change.astype(np.float32)
|
||||||
|
return change_f32, change_f32, int(np.count_nonzero(change))
|
||||||
|
|
||||||
if self.auto_threshold:
|
if self.auto_threshold:
|
||||||
self._update_auto_threshold(change)
|
self._update_auto_threshold(change)
|
||||||
|
|
||||||
|
|||||||
@@ -170,42 +170,43 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
print(f"Loaded checkpoint from {args.checkpoint} (epoch={ckpt.get('epoch', '?')})")
|
print(f"Loaded checkpoint from {args.checkpoint} (epoch={ckpt.get('epoch', '?')})")
|
||||||
|
|
||||||
# Validation loader (use test scenes for final eval)
|
# Evaluate each scene independently → NaN gaps prevent plot mixing
|
||||||
from src.velocity_prediction.config import TEST_SCENES
|
from src.velocity_prediction.config import TEST_SCENES
|
||||||
|
all_preds, all_targets = [], []
|
||||||
|
scene_rmses = []
|
||||||
|
|
||||||
|
for scene in TEST_SCENES:
|
||||||
loader = create_val_loader(
|
loader = create_val_loader(
|
||||||
scene_names=TEST_SCENES,
|
scene_names=[scene],
|
||||||
seq_len=train_cfg.seq_len,
|
seq_len=train_cfg.seq_len,
|
||||||
batch_size=train_cfg.batch_size,
|
batch_size=train_cfg.batch_size,
|
||||||
num_workers=2,
|
num_workers=2,
|
||||||
event_threshold=train_cfg.event_threshold,
|
event_threshold=train_cfg.event_threshold,
|
||||||
event_use_log=train_cfg.event_use_log,
|
event_use_log=train_cfg.event_use_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
# # ── Quick event diagnostics: inspect one batch ───────────────
|
|
||||||
# print("\n========== Event Frame Diagnostics ==========")
|
|
||||||
# sample_batch = next(iter(loader))
|
|
||||||
# ev = sample_batch["events"] # (B, S, 1, H, W)
|
|
||||||
# print(f"Events shape: {ev.shape}")
|
|
||||||
# print(f"Events dtype: {ev.dtype}")
|
|
||||||
# print(f"Events value counts: -1: {(ev == -1).sum().item()}, "
|
|
||||||
# f"0: {(ev == 0).sum().item()}, +1: {(ev == 1).sum().item()}")
|
|
||||||
# total_el = ev.numel()
|
|
||||||
# nonzero = (ev != 0).sum().item()
|
|
||||||
# print(f"Non-zero ratio: {nonzero / total_el:.6f} ({nonzero}/{total_el})")
|
|
||||||
# print(f"Per-sample non-zero: {[(ev[b] != 0).sum().item() for b in range(min(4, ev.shape[0]))]}")
|
|
||||||
# print("=============================================\n")
|
|
||||||
|
|
||||||
# Evaluate
|
|
||||||
results = evaluate(model, loader, device)
|
results = evaluate(model, loader, device)
|
||||||
print(f"\nEvaluation results on test scenes: {TEST_SCENES}")
|
n = len(results["preds"])
|
||||||
print(f" RMSE vx: {results['rmse_x']:.4f} m/s")
|
print(f" [{scene}] RMSE vx={results['rmse_x']:.4f} vy={results['rmse_y']:.4f} "
|
||||||
print(f" RMSE vy: {results['rmse_y']:.4f} m/s")
|
f"xy={results['rmse_xy']:.4f} samples={n}")
|
||||||
print(f" RMSE xy: {results['rmse_xy']:.4f} m/s")
|
scene_rmses.append(results["rmse_xy"])
|
||||||
|
|
||||||
# Plots
|
all_preds.append(results["preds"])
|
||||||
|
all_targets.append(results["targets"])
|
||||||
|
# NaN separator → plot won't connect discontinuous scenes
|
||||||
|
sep = np.full((1, 2), np.nan, dtype=np.float32)
|
||||||
|
all_preds.append(sep)
|
||||||
|
all_targets.append(sep)
|
||||||
|
|
||||||
|
# Overall RMSE = mean across scenes (unweighted, avoids scene size bias)
|
||||||
|
rmse_xy = np.mean(scene_rmses)
|
||||||
|
print(f"\nOverall ({len(TEST_SCENES)} scenes, mean across scenes): RMSE xy={rmse_xy:.4f} m/s")
|
||||||
|
|
||||||
|
# Plots (with NaN gaps between scenes)
|
||||||
if args.plot:
|
if args.plot:
|
||||||
plot_results(results["preds"], results["targets"], "eval_velocity.png")
|
preds_cat = np.concatenate(all_preds, axis=0)
|
||||||
plot_scatter(results["preds"], results["targets"], "eval_scatter.png")
|
targets_cat = np.concatenate(all_targets, axis=0)
|
||||||
|
plot_results(preds_cat, targets_cat, "eval_velocity.png")
|
||||||
|
plot_scatter(preds_cat, targets_cat, "eval_scatter.png")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class CNNEncoder(nn.Module):
|
|||||||
for out_ch in channels:
|
for out_ch in channels:
|
||||||
layers.extend([
|
layers.extend([
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2),
|
nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2),
|
||||||
# nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(),
|
nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(),
|
||||||
nn.Identity(),
|
nn.Identity(),
|
||||||
nn.LeakyReLU(inplace=True),
|
nn.LeakyReLU(inplace=True),
|
||||||
nn.MaxPool2d(cfg.pool_size),
|
nn.MaxPool2d(cfg.pool_size),
|
||||||
@@ -119,8 +119,10 @@ class VelocityPredictionModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# # Small init for the final layer: start from near-zero output
|
# Small init for the final layer: start from near-zero output
|
||||||
# self.head[-1].weight.data.mul_(0.01)
|
self.head[-1].weight.data.mul_(0.01)
|
||||||
# self.head[-1].bias.data.zero_()
|
self.head[-1].bias.data.zero_()
|
||||||
|
# nn.init.uniform_(self.head[-1].weight, -0.001, 0.001)
|
||||||
|
# nn.init.zeros_(self.head[-1].bias)
|
||||||
|
|
||||||
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
|
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -132,9 +134,9 @@ class VelocityPredictionModel(nn.Module):
|
|||||||
v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep
|
v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep
|
||||||
"""
|
"""
|
||||||
# Per-frame encoding
|
# Per-frame encoding
|
||||||
# cnn_feat = self.cnn(events) # (B, S, 256)
|
cnn_feat = self.cnn(events) # (B, S, 256)
|
||||||
B, S = events.shape[:2]
|
# B, S = events.shape[:2]
|
||||||
cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # all-zero substitute for the CNN features
|
# cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # all-zero substitute for the CNN features
|
||||||
|
|
||||||
pose_feat = self.pose_mlp(tilt) # (B, S, 64)
|
pose_feat = self.pose_mlp(tilt) # (B, S, 64)
|
||||||
|
|
||||||
|
|||||||
@@ -32,11 +32,13 @@ def train_one_epoch(
|
|||||||
loader,
|
loader,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
criterion: nn.Module,
|
criterion: nn.Module,
|
||||||
|
scaler: torch.cuda.amp.GradScaler,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
epoch: int,
|
epoch: int,
|
||||||
writer: SummaryWriter,
|
writer: SummaryWriter,
|
||||||
log_interval: int = 50,
|
log_interval: int = 50,
|
||||||
global_step: int = 0,
|
global_step: int = 0,
|
||||||
|
use_amp: bool = True,
|
||||||
) -> tuple[float, int]:
|
) -> tuple[float, int]:
|
||||||
"""Train for one epoch. Returns (avg_loss, updated_global_step)."""
|
"""Train for one epoch. Returns (avg_loss, updated_global_step)."""
|
||||||
model.train()
|
model.train()
|
||||||
@@ -50,14 +52,15 @@ def train_one_epoch(
|
|||||||
target = batch["v_body_target"].to(device) # (B, S, 2)
|
target = batch["v_body_target"].to(device) # (B, S, 2)
|
||||||
|
|
||||||
# Predict velocity for the last frame in the sequence
|
# Predict velocity for the last frame in the sequence
|
||||||
|
with torch.amp.autocast(device.type, enabled=use_amp):
|
||||||
pred = model(events, tilt) # (B, 2)
|
pred = model(events, tilt) # (B, 2)
|
||||||
target_last = target[:, -1, :] # (B, 2)
|
target_last = target[:, -1, :] # (B, 2)
|
||||||
|
|
||||||
loss = criterion(pred, target_last)
|
loss = criterion(pred, target_last)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
scaler.scale(loss).backward()
|
||||||
optimizer.step()
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
num_batches += 1
|
num_batches += 1
|
||||||
@@ -79,6 +82,7 @@ def validate(
|
|||||||
loader,
|
loader,
|
||||||
criterion: nn.Module,
|
criterion: nn.Module,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
use_amp: bool = True,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Validate. Returns average loss."""
|
"""Validate. Returns average loss."""
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -90,10 +94,11 @@ def validate(
|
|||||||
tilt = batch["tilt"].to(device)
|
tilt = batch["tilt"].to(device)
|
||||||
target = batch["v_body_target"].to(device)
|
target = batch["v_body_target"].to(device)
|
||||||
|
|
||||||
|
with torch.amp.autocast(device.type, enabled=use_amp):
|
||||||
pred = model(events, tilt)
|
pred = model(events, tilt)
|
||||||
target_last = target[:, -1, :]
|
target_last = target[:, -1, :]
|
||||||
|
|
||||||
loss = criterion(pred, target_last)
|
loss = criterion(pred, target_last)
|
||||||
|
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
num_batches += 1
|
num_batches += 1
|
||||||
|
|
||||||
@@ -107,7 +112,10 @@ def main():
|
|||||||
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
|
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
|
||||||
parser.add_argument("--resume", type=str, default=None,
|
parser.add_argument("--resume", type=str, default=None,
|
||||||
help="Path to checkpoint .pt file to resume training from")
|
help="Path to checkpoint .pt file to resume training from")
|
||||||
|
parser.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True,
|
||||||
|
help="Enable Automatic Mixed Precision (default: True)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
use_amp = args.amp
|
||||||
|
|
||||||
set_seed(train_cfg.seed)
|
set_seed(train_cfg.seed)
|
||||||
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
|
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
|
||||||
@@ -116,8 +124,10 @@ def main():
|
|||||||
# Create model
|
# Create model
|
||||||
model = VelocityPredictionModel()
|
model = VelocityPredictionModel()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
scaler = torch.amp.GradScaler(device.type, enabled=use_amp)
|
||||||
total_params = count_parameters(model)
|
total_params = count_parameters(model)
|
||||||
print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")
|
print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")
|
||||||
|
print(f"AMP: {'enabled' if use_amp else 'disabled'}")
|
||||||
|
|
||||||
# Data loaders
|
# Data loaders
|
||||||
train_loader = create_train_loader(
|
train_loader = create_train_loader(
|
||||||
@@ -196,11 +206,12 @@ def main():
|
|||||||
epoch_start = time.time()
|
epoch_start = time.time()
|
||||||
|
|
||||||
train_loss, global_step = train_one_epoch(
|
train_loss, global_step = train_one_epoch(
|
||||||
model, train_loader, optimizer, criterion, device, epoch, writer,
|
model, train_loader, optimizer, criterion, scaler, device, epoch, writer,
|
||||||
log_interval=train_cfg.log_interval,
|
log_interval=train_cfg.log_interval,
|
||||||
global_step=global_step,
|
global_step=global_step,
|
||||||
|
use_amp=use_amp,
|
||||||
)
|
)
|
||||||
val_loss = validate(model, val_loader, criterion, device)
|
val_loss = validate(model, val_loader, criterion, device, use_amp=use_amp)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
epoch_time = time.time() - epoch_start
|
epoch_time = time.time() - epoch_start
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ def build_train_transform(event_threshold=0.1, event_use_log=True):
|
|||||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||||
ComputeTilt(),
|
ComputeTilt(),
|
||||||
ComputeBodyVelocity(),
|
ComputeBodyVelocity(),
|
||||||
NormalizeVelocity(),
|
# NormalizeVelocity(),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
@@ -150,5 +150,5 @@ def build_val_transform(event_threshold=0.1, event_use_log=True):
|
|||||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||||
ComputeTilt(),
|
ComputeTilt(),
|
||||||
ComputeBodyVelocity(),
|
ComputeBodyVelocity(),
|
||||||
NormalizeVelocity(),
|
# NormalizeVelocity(),
|
||||||
])
|
])
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from src.velocity_prediction.utils import (
|
|||||||
R_ODOM_TO_BODY,
|
R_ODOM_TO_BODY,
|
||||||
)
|
)
|
||||||
from src.velocity_prediction.config import DATASET_ROOT, VELOCITY_MEAN, VELOCITY_STD
|
from src.velocity_prediction.config import DATASET_ROOT, VELOCITY_MEAN, VELOCITY_STD
|
||||||
|
from src.event_utils import EventProcessor
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────── Data loading ────────────────────────────
|
# ──────────────────────────── Data loading ────────────────────────────
|
||||||
@@ -140,6 +141,8 @@ def draw_pose_overlay(
|
|||||||
euler: np.ndarray,
|
euler: np.ndarray,
|
||||||
frame_idx: int,
|
frame_idx: int,
|
||||||
ts: float,
|
ts: float,
|
||||||
|
events: np.ndarray | None = None,
|
||||||
|
show_events: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Draw body-frame pose and velocity information onto the image.
|
Draw body-frame pose and velocity information onto the image.
|
||||||
@@ -285,6 +288,21 @@ def draw_pose_overlay(
|
|||||||
cv2.LINE_AA,
|
cv2.LINE_AA,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ── Event overlay (gradient temporal intensity) ──
|
||||||
|
if show_events and events is not None:
|
||||||
|
limit = max(np.abs(events).max(), 1e-6)
|
||||||
|
norm = np.clip(events / limit, -1.0, 1.0)
|
||||||
|
pos = norm > 0
|
||||||
|
neg = norm < 0
|
||||||
|
intensity = np.abs(norm) # [0, 1] magnitude
|
||||||
|
overlay = np.zeros_like(display, dtype=np.uint8)
|
||||||
|
# bg = np.ones_like(display, dtype=np.uint8) * 255
|
||||||
|
# Color intensity proportional to |norm|: dark → bright
|
||||||
|
overlay[pos, 1] = (255 * intensity[pos]).astype(np.uint8) # green channel
|
||||||
|
overlay[neg, 2] = (255 * intensity[neg]).astype(np.uint8) # red channel
|
||||||
|
# display = cv2.addWeighted(bg, 0.5, overlay, 1.0, 0)
|
||||||
|
display = cv2.addWeighted(display, 0.5, overlay, 1.0, 0)
|
||||||
|
|
||||||
return display
|
return display
|
||||||
|
|
||||||
|
|
||||||
@@ -297,6 +315,7 @@ def create_video(
|
|||||||
fps: float = 30.0,
|
fps: float = 30.0,
|
||||||
max_frames: int | None = None,
|
max_frames: int | None = None,
|
||||||
show: bool = False,
|
show: bool = False,
|
||||||
|
show_events: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read scene data, overlay pose info, and write to video file (or show).
|
Read scene data, overlay pose info, and write to video file (or show).
|
||||||
@@ -316,6 +335,9 @@ def create_video(
|
|||||||
# Reset attitude offset for this scene
|
# Reset attitude offset for this scene
|
||||||
reset_attitude_offset()
|
reset_attitude_offset()
|
||||||
|
|
||||||
|
# Event processor (threshold=0.3 on log brightness → thresholded events, not raw mode)
|
||||||
|
event_processor = EventProcessor(threshold=0.3, use_log=True) if show_events else None
|
||||||
|
|
||||||
# Get dimensions from first frame
|
# Get dimensions from first frame
|
||||||
h, w = frames[0]["img"].shape
|
h, w = frames[0]["img"].shape
|
||||||
|
|
||||||
@@ -334,6 +356,12 @@ def create_video(
|
|||||||
for i, frame_data in enumerate(frames):
|
for i, frame_data in enumerate(frames):
|
||||||
q_raw = frame_data["pose"][3:7] # [qx, qy, qz, qw] world→odom
|
q_raw = frame_data["pose"][3:7] # [qx, qy, qz, qw] world→odom
|
||||||
|
|
||||||
|
# Compute events if enabled
|
||||||
|
if event_processor is not None:
|
||||||
|
events_binary, _, _ = event_processor(frame_data["img"])
|
||||||
|
else:
|
||||||
|
events_binary = None
|
||||||
|
|
||||||
# Body up vector (pitch & roll only, no yaw) — matches DiffPhysDrone
|
# Body up vector (pitch & roll only, no yaw) — matches DiffPhysDrone
|
||||||
body_up = body_up_vector_np(q_raw) # (3,) unit vector
|
body_up = body_up_vector_np(q_raw) # (3,) unit vector
|
||||||
|
|
||||||
@@ -354,6 +382,8 @@ def create_video(
|
|||||||
euler=euler_deg,
|
euler=euler_deg,
|
||||||
frame_idx=i,
|
frame_idx=i,
|
||||||
ts=frame_data["ts"],
|
ts=frame_data["ts"],
|
||||||
|
events=events_binary,
|
||||||
|
show_events=show_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
if show:
|
if show:
|
||||||
@@ -404,6 +434,9 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--show", action="store_true", help="Display on screen instead of saving video"
|
"--show", action="store_true", help="Display on screen instead of saving video"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--show-events", action="store_true", help="Overlay event frames (green=+1, red=-1)"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Collect scenes to process
|
# Collect scenes to process
|
||||||
@@ -436,6 +469,7 @@ def main():
|
|||||||
fps=args.fps,
|
fps=args.fps,
|
||||||
max_frames=args.max_frames,
|
max_frames=args.max_frames,
|
||||||
show=args.show,
|
show=args.show,
|
||||||
|
show_events=args.show_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user