Fix this bug when setting distillation-type to none

This commit is contained in:
caitianren
2023-11-29 20:15:00 +08:00
parent cd1f854e59
commit 0ddadad723

View File

@@ -437,7 +437,7 @@ class SwiftFormer(nn.Module):
if not self.training: if not self.training:
cls_out = (cls_out[0] + cls_out[1]) / 2 cls_out = (cls_out[0] + cls_out[1]) / 2
else: else:
cls_out = self.head(x.mean(-2)) cls_out = self.head(x.flatten(2).mean(-1))
# For image classification # For image classification
return cls_out return cls_out