初步可跑通,但 loss 计算有问题,训练不收敛
This commit is contained in:
@@ -6,9 +6,9 @@ import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
from timm.models.registry import register_model
|
||||
from timm.models.layers.helpers import to_2tuple
|
||||
from timm.layers import DropPath, trunc_normal_
|
||||
from timm.models import register_model
|
||||
from timm.layers import to_2tuple
|
||||
import einops
|
||||
|
||||
SwiftFormer_width = {
|
||||
|
||||
@@ -7,7 +7,7 @@ from .swiftformer import (
|
||||
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
|
||||
stem, Embedding, Stage
|
||||
)
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
from timm.layers import DropPath, trunc_normal_
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
@@ -31,7 +31,7 @@ class DecoderBlock(nn.Module):
|
||||
|
||||
class FramePredictionDecoder(nn.Module):
|
||||
"""Lightweight decoder for frame prediction with optional skip connections"""
|
||||
def __init__(self, embed_dims, output_channels=3, use_skip=False):
|
||||
def __init__(self, embed_dims, output_channels=1, use_skip=False):
|
||||
super().__init__()
|
||||
self.use_skip = use_skip
|
||||
# Reverse the embed_dims for decoder
|
||||
@@ -53,11 +53,11 @@ class FramePredictionDecoder(nn.Module):
|
||||
decoder_dims[2], decoder_dims[3],
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
# stage2 to original resolution (4x upsampling total)
|
||||
# stage2 to original resolution (now 8x upsampling total with stride 4)
|
||||
self.blocks.append(nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
decoder_dims[3], 32,
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
kernel_size=3, stride=4, padding=1, output_padding=3
|
||||
),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
@@ -104,7 +104,7 @@ class SwiftFormerTemporal(nn.Module):
|
||||
"""
|
||||
SwiftFormer with temporal input for frame prediction.
|
||||
Input: [B, num_frames, H, W] (Y channel only)
|
||||
Output: predicted frame [B, 3, H, W] and optional representation
|
||||
Output: predicted frame [B, 1, H, W] and optional representation
|
||||
"""
|
||||
def __init__(self,
|
||||
model_name='XS',
|
||||
@@ -155,7 +155,7 @@ class SwiftFormerTemporal(nn.Module):
|
||||
|
||||
# Frame prediction decoder
|
||||
if use_decoder:
|
||||
self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
|
||||
self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
|
||||
|
||||
# Representation head for pose/velocity prediction
|
||||
if use_representation_head:
|
||||
@@ -201,7 +201,7 @@ class SwiftFormerTemporal(nn.Module):
|
||||
x: input frames of shape [B, num_frames, H, W]
|
||||
Returns:
|
||||
If return_features is False:
|
||||
pred_frame: predicted frame [B, 3, H, W] (or None)
|
||||
pred_frame: predicted frame [B, 1, H, W] (or None)
|
||||
representation: optional representation vector [B, representation_dim] (or None)
|
||||
If return_features is True:
|
||||
pred_frame, representation, features (list of stage features)
|
||||
|
||||
Reference in New Issue
Block a user