Mirror of https://github.com/ostris/ai-toolkit.git
Fixed issue with adapters not providing gradients with new grad activator
@@ -1473,7 +1473,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # self.step_num = self.embedding.step
             # self.start_step = self.step_num
             params.append({
-                'params': self.embedding.get_trainable_params(),
+                'params': list(self.embedding.get_trainable_params()),
                 'lr': self.train_config.embedding_lr
             })
 
@@ -1491,7 +1491,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             # set trainable params
             params.append({
-                'params': self.adapter.parameters(),
+                'params': list(self.adapter.parameters()),
                 'lr': self.train_config.adapter_lr
             })
 
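The change in both hunks above is the same fix: torch.nn.Module.parameters() and get_trainable_params() return one-shot generators, so any code that iterates a param group a second time (for example, a gradient activator toggling requires_grad after the optimizer has already consumed the group) sees an empty sequence, and the adapter weights silently stop receiving gradients. A minimal standalone sketch of the failure mode, assuming nothing about the repo's actual trainer:

    import torch
    import torch.nn as nn

    adapter = nn.Linear(4, 4)

    # A generator from .parameters() can be iterated exactly once.
    group = {'params': adapter.parameters(), 'lr': 1e-4}
    optimizer = torch.optim.AdamW(group['params'], lr=group['lr'])  # drains it

    # A second pass over the same group now sees nothing, so anything that
    # tries to re-enable gradients here is a silent no-op.
    print(len(list(group['params'])))  # 0

    # Materializing the parameters up front keeps the group reusable.
    group = {'params': list(adapter.parameters()), 'lr': 1e-4}
    optimizer = torch.optim.AdamW(group['params'], lr=group['lr'])
    print(len(list(group['params'])))  # 2: weight and bias still visible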
@@ -1161,13 +1161,13 @@ class IPAdapter(torch.nn.Module):
         # when training just scaler, we do not train anything else
         if not self.config.train_scaler:
             param_groups.append({
-                "params": self.get_non_scaler_parameters(),
+                "params": list(self.get_non_scaler_parameters()),
                 "lr": adapter_lr,
             })
         if self.config.train_scaler or self.config.merge_scaler:
             scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
             param_groups.append({
-                "params": self.get_scaler_parameters(),
+                "params": list(self.get_scaler_parameters()),
                 "lr": scaler_lr,
             })
         return param_groups
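For context, the groups returned by this method feed directly into a PyTorch optimizer, which tracks a separate learning rate per group. A hedged sketch of that layout with placeholder modules (proj and scaler stand in for the adapter's real submodules; the lr fallback mirrors the scaler_lr logic above):

    import torch
    import torch.nn as nn

    proj = nn.Linear(8, 8)                # stands in for the non-scaler weights
    scaler = nn.Parameter(torch.ones(1))  # stands in for the trainable scaler

    adapter_lr = 1e-4
    scaler_lr = None  # None falls back to adapter_lr, as in the code above

    param_groups = [
        {"params": list(proj.parameters()), "lr": adapter_lr},
        {"params": [scaler], "lr": adapter_lr if scaler_lr is None else scaler_lr},
    ]

    # Each group keeps its own learning rate inside the optimizer.
    optimizer = torch.optim.AdamW(param_groups)
    for group in optimizer.param_groups:
        print(group["lr"], sum(p.numel() for p in group["params"]))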
@@ -711,20 +711,30 @@ class VisionDirectAdapter(torch.nn.Module):
             self.block_scaler.requires_grad = True
         else:
             self.block_scaler = None
 
+        self.pool = None
+
         if self.config.num_tokens is not None:
-            image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
+            # image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
             # max_seq_len = CLIP tokens + CLS token
-            max_seq_len = 257
-            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
-                # clip
-                max_seq_len = int(
-                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
-            self.resampler = MLPR(
-                in_dim=self.token_size,
-                in_channels=max_seq_len,
-                out_dim=self.mid_size,
-                out_channels=self.config.num_tokens,
-            )
+            # max_seq_len = 257
+            # if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+            #     # clip
+            #     max_seq_len = int(
+            #         image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+            # self.resampler = MLPR(
+            #     in_dim=self.token_size,
+            #     in_channels=max_seq_len,
+            #     out_dim=self.mid_size,
+            #     out_channels=self.config.num_tokens,
+            # )
+            vision_config = self.adapter_ref().vision_encoder.config
+            # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
+            # siglip doesnt add 1
+            sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
+            self.pool = nn.Sequential(
+                nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False),
+                Norm(),
+            )
 
         elif self.config.image_encoder_arch == "pixtral":
@@ -733,7 +743,6 @@ class VisionDirectAdapter(torch.nn.Module):
                 out_dim=self.mid_size,
             )
 
-        self.pool = None
         self.sparse_autoencoder = None
         if self.config.conv_pooling:
             vision_config = self.adapter_ref().vision_encoder.config
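The sequence-length arithmetic in the last two hunks is worth spelling out: CLIP ViT-L/14 at 224px yields (224 / 14)^2 = 256 patch tokens plus a CLS token, which is where the old hard-coded max_seq_len = 257 came from, while SigLIP adds no CLS token. A sketch of the new Conv1d pooling path under assumed sizes (384px input, 16px patches, 16 output tokens; these are illustrative values, not the repo's config, and the repo's Norm() layer is omitted):

    import torch
    import torch.nn as nn

    # Assumed SigLIP-style encoder: 384px input, 16px patches, no CLS token.
    image_size, patch_size = 384, 16
    sequence_length = int((image_size / patch_size) ** 2)  # 576
    num_tokens, token_dim = 16, 1152  # illustrative sizes only

    # Kernel size 1 treats the token axis as channels, so the layer learns a
    # weighted mix of all 576 input tokens for each of the 16 output tokens.
    pool = nn.Conv1d(sequence_length, num_tokens, 1, bias=False)

    feats = torch.randn(2, sequence_length, token_dim)  # (batch, tokens, dim)
    pooled = pool(feats)
    print(pooled.shape)  # torch.Size([2, 16, 1152])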