mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Fixed issue with adapters not providing gradients with new grad activator
This commit is contained in:
@@ -711,20 +711,30 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
self.block_scaler.requires_grad = True
|
||||
else:
|
||||
self.block_scaler = None
|
||||
|
||||
self.pool = None
|
||||
|
||||
if self.config.num_tokens is not None:
|
||||
image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
|
||||
# image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
|
||||
# max_seq_len = CLIP tokens + CLS token
|
||||
max_seq_len = 257
|
||||
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
|
||||
# clip
|
||||
max_seq_len = int(
|
||||
image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
|
||||
self.resampler = MLPR(
|
||||
in_dim=self.token_size,
|
||||
in_channels=max_seq_len,
|
||||
out_dim=self.mid_size,
|
||||
out_channels=self.config.num_tokens,
|
||||
# max_seq_len = 257
|
||||
# if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
|
||||
# # clip
|
||||
# max_seq_len = int(
|
||||
# image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
|
||||
# self.resampler = MLPR(
|
||||
# in_dim=self.token_size,
|
||||
# in_channels=max_seq_len,
|
||||
# out_dim=self.mid_size,
|
||||
# out_channels=self.config.num_tokens,
|
||||
# )
|
||||
vision_config = self.adapter_ref().vision_encoder.config
|
||||
# sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
|
||||
# siglip doesnt add 1
|
||||
sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
|
||||
self.pool = nn.Sequential(
|
||||
nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False),
|
||||
Norm(),
|
||||
)
|
||||
|
||||
elif self.config.image_encoder_arch == "pixtral":
|
||||
@@ -733,7 +743,6 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
out_dim=self.mid_size,
|
||||
)
|
||||
|
||||
self.pool = None
|
||||
self.sparse_autoencoder = None
|
||||
if self.config.conv_pooling:
|
||||
vision_config = self.adapter_ref().vision_encoder.config
|
||||
|
||||
Reference in New Issue
Block a user