Runs end-to-end for now, but the loss computation has a bug and training does not converge
@@ -80,10 +80,13 @@ class VideoFrameDataset(Dataset):
         else:
             self.transform = transform
 
-        # Normalization (ImageNet stats)
+        # Normalization for Y channel (single channel)
+        # Compute average of ImageNet RGB means and stds
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
+            mean=[y_mean],
+            std=[y_std]
         )
 
     def _default_transform(self):
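Note on the averaged stats above: a minimal sketch (assuming only torch and torchvision, which the file already uses) showing the values this evaluates to and how a single-channel Normalize applies to a [1, H, W] tensor:

import torch
from torchvision import transforms

y_mean = (0.485 + 0.456 + 0.406) / 3.0  # = 0.449
y_std = (0.229 + 0.224 + 0.225) / 3.0   # = 0.226
normalize = transforms.Normalize(mean=[y_mean], std=[y_std])

y = torch.rand(1, 64, 64)   # fake Y-channel frame with values in [0, 1]
out = normalize(y)          # elementwise (y - 0.449) / 0.226
print(out.shape)            # torch.Size([1, 64, 64])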
@@ -114,8 +117,8 @@ class VideoFrameDataset(Dataset):
     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Returns:
-            input_frames: [3 * num_frames, H, W] concatenated input frames
-            target_frame: [3, H, W] target frame to predict
+            input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
+            target_frame: [1, H, W] target frame to predict (Y channel only)
             temporal_idx: temporal index of target frame (for contrastive loss)
         """
         video_idx, start_idx = self.frame_indices[idx]
@@ -141,23 +144,27 @@ class VideoFrameDataset(Dataset):
         if self.transform:
             target_frame = self.transform(target_frame)
 
-        # Convert to tensors and normalize
+        # Convert to tensors, normalize, and convert to grayscale (Y channel)
         input_tensors = []
         for frame in input_frames:
-            tensor = transforms.ToTensor()(frame)
-            tensor = self.normalize(tensor)
-            input_tensors.append(tensor)
+            tensor = transforms.ToTensor()(frame)  # [3, H, W]
+            # Convert RGB to grayscale using weighted sum
+            # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
+            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W]
+            gray = self.normalize(gray)  # normalize with single-channel stats (mean/std broadcast)
+            input_tensors.append(gray)
 
-        target_tensor = transforms.ToTensor()(target_frame)
-        target_tensor = self.normalize(target_tensor)
+        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W]
+        target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
+        target_gray = self.normalize(target_gray)
 
         # Concatenate input frames along channel dimension
-        input_concatenated = torch.cat(input_tensors, dim=0)
+        input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]
 
         # Temporal index (for contrastive loss)
         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
 
-        return input_concatenated, target_tensor, temporal_idx
+        return input_concatenated, target_gray, temporal_idx
 
 
 class SyntheticVideoDataset(Dataset):
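The 0.2989 / 0.5870 / 0.1140 weights are the BT.601 luma coefficients, which is also what PIL's "L" mode uses. A minimal sanity check for the weighted-sum conversion (a sketch assuming Pillow and torchvision are available; the tolerance covers Pillow's 8-bit rounding):

import torch
from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (8, 8), (120, 200, 40))     # hypothetical solid-color frame
t = transforms.ToTensor()(img)                     # [3, H, W] with values in [0, 1]
y = 0.2989 * t[0] + 0.5870 * t[1] + 0.1140 * t[2]  # weighted sum from the diff
ref = transforms.ToTensor()(img.convert("L"))[0]   # PIL's own grayscale conversion
assert torch.allclose(y, ref, atol=2 / 255)        # agree up to 8-bit quantization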
@@ -174,10 +181,12 @@ class SyntheticVideoDataset(Dataset):
         self.frame_size = frame_size
         self.is_train = is_train
 
-        # Normalization
+        # Normalization for Y channel (single channel)
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
+            mean=[y_mean],
+            std=[y_std]
         )
 
     def __len__(self):
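One downstream consequence of this change: the concatenated input is now [num_frames, H, W] instead of [3 * num_frames, H, W], so the first convolution of whatever model consumes these batches must take num_frames input channels. A hedged sketch (layer and sizes are illustrative assumptions, not from this commit):

import torch
import torch.nn as nn

num_frames = 4  # assumed value for illustration
first_conv = nn.Conv2d(in_channels=num_frames,  # was 3 * num_frames for RGB input
                       out_channels=64, kernel_size=3, padding=1)

batch = torch.randn(2, num_frames, 128, 128)  # batch of normalized Y-channel stacks
print(first_conv(batch).shape)                # torch.Size([2, 64, 128, 128])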