diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index 75fac738..e594378b 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -445,8 +445,8 @@ class CustomAdapter(torch.nn.Module): return state_dict elif self.adapter_type == 'vision_direct': state_dict["dvadapter"] = self.vd_adapter.state_dict() - if self.config.train_image_encoder: - state_dict["vision_encoder"] = self.vision_encoder.state_dict() + # if self.config.train_image_encoder: # always return vision encoder + state_dict["vision_encoder"] = self.vision_encoder.state_dict() return state_dict elif self.adapter_type == 'single_value': state_dict["sv_adapter"] = self.single_value_adapter.state_dict()