diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index e594378b..23367bbf 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -889,6 +889,8 @@ class CustomAdapter(torch.nn.Module): yield from attn_processor.parameters(recurse) if self.config.train_image_encoder: yield from self.vision_encoder.parameters(recurse) + if self.config.num_tokens: + yield from self.vd_adapter.resampler.parameters(recurse) elif self.config.type == 'te_augmenter': yield from self.te_augmenter.parameters(recurse) if self.config.train_image_encoder: