mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Fixed issue with adapters not providing gradients with new grad activator
This commit is contained in:
@@ -1473,7 +1473,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# self.step_num = self.embedding.step
|
# self.step_num = self.embedding.step
|
||||||
# self.start_step = self.step_num
|
# self.start_step = self.step_num
|
||||||
params.append({
|
params.append({
|
||||||
'params': self.embedding.get_trainable_params(),
|
'params': list(self.embedding.get_trainable_params()),
|
||||||
'lr': self.train_config.embedding_lr
|
'lr': self.train_config.embedding_lr
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1491,7 +1491,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
else:
|
else:
|
||||||
# set trainable params
|
# set trainable params
|
||||||
params.append({
|
params.append({
|
||||||
'params': self.adapter.parameters(),
|
'params': list(self.adapter.parameters()),
|
||||||
'lr': self.train_config.adapter_lr
|
'lr': self.train_config.adapter_lr
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -1161,13 +1161,13 @@ class IPAdapter(torch.nn.Module):
|
|||||||
# when training just scaler, we do not train anything else
|
# when training just scaler, we do not train anything else
|
||||||
if not self.config.train_scaler:
|
if not self.config.train_scaler:
|
||||||
param_groups.append({
|
param_groups.append({
|
||||||
"params": self.get_non_scaler_parameters(),
|
"params": list(self.get_non_scaler_parameters()),
|
||||||
"lr": adapter_lr,
|
"lr": adapter_lr,
|
||||||
})
|
})
|
||||||
if self.config.train_scaler or self.config.merge_scaler:
|
if self.config.train_scaler or self.config.merge_scaler:
|
||||||
scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
|
scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
|
||||||
param_groups.append({
|
param_groups.append({
|
||||||
"params": self.get_scaler_parameters(),
|
"params": list(self.get_scaler_parameters()),
|
||||||
"lr": scaler_lr,
|
"lr": scaler_lr,
|
||||||
})
|
})
|
||||||
return param_groups
|
return param_groups
|
||||||
|
|||||||
@@ -711,20 +711,30 @@ class VisionDirectAdapter(torch.nn.Module):
|
|||||||
self.block_scaler.requires_grad = True
|
self.block_scaler.requires_grad = True
|
||||||
else:
|
else:
|
||||||
self.block_scaler = None
|
self.block_scaler = None
|
||||||
|
|
||||||
|
self.pool = None
|
||||||
|
|
||||||
if self.config.num_tokens is not None:
|
if self.config.num_tokens is not None:
|
||||||
image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
|
# image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
|
||||||
# max_seq_len = CLIP tokens + CLS token
|
# max_seq_len = CLIP tokens + CLS token
|
||||||
max_seq_len = 257
|
# max_seq_len = 257
|
||||||
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
|
# if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
|
||||||
# clip
|
# # clip
|
||||||
max_seq_len = int(
|
# max_seq_len = int(
|
||||||
image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
|
# image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
|
||||||
self.resampler = MLPR(
|
# self.resampler = MLPR(
|
||||||
in_dim=self.token_size,
|
# in_dim=self.token_size,
|
||||||
in_channels=max_seq_len,
|
# in_channels=max_seq_len,
|
||||||
out_dim=self.mid_size,
|
# out_dim=self.mid_size,
|
||||||
out_channels=self.config.num_tokens,
|
# out_channels=self.config.num_tokens,
|
||||||
|
# )
|
||||||
|
vision_config = self.adapter_ref().vision_encoder.config
|
||||||
|
# sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
|
||||||
|
# siglip doesn't add 1
|
||||||
|
sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
|
||||||
|
self.pool = nn.Sequential(
|
||||||
|
nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False),
|
||||||
|
Norm(),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.config.image_encoder_arch == "pixtral":
|
elif self.config.image_encoder_arch == "pixtral":
|
||||||
@@ -733,7 +743,6 @@ class VisionDirectAdapter(torch.nn.Module):
|
|||||||
out_dim=self.mid_size,
|
out_dim=self.mid_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pool = None
|
|
||||||
self.sparse_autoencoder = None
|
self.sparse_autoencoder = None
|
||||||
if self.config.conv_pooling:
|
if self.config.conv_pooling:
|
||||||
vision_config = self.adapter_ref().vision_encoder.config
|
vision_config = self.adapter_ref().vision_encoder.config
|
||||||
|
|||||||
Reference in New Issue
Block a user