Merge pull request #12 from ThomasCai/main
Fix the issue when the distillation type is set to none.
This commit is contained in:
@@ -437,7 +437,7 @@ class SwiftFormer(nn.Module):
|
||||
if not self.training:
|
||||
cls_out = (cls_out[0] + cls_out[1]) / 2
|
||||
else:
|
||||
cls_out = self.head(x.mean(-2))
|
||||
cls_out = self.head(x.flatten(2).mean(-1))
|
||||
# For image classification
|
||||
return cls_out
|
||||
|
||||
|
||||
Reference in New Issue
Block a user