Fixed issue with adapters not providing gradients with new grad activator

This commit is contained in:
Jaret Burkett
2024-10-29 14:22:10 -06:00
parent 22cd40d7b9
commit 4747716867
3 changed files with 25 additions and 16 deletions

View File

@@ -1473,7 +1473,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# self.step_num = self.embedding.step
# self.start_step = self.step_num
params.append({
'params': self.embedding.get_trainable_params(),
'params': list(self.embedding.get_trainable_params()),
'lr': self.train_config.embedding_lr
})
@@ -1491,7 +1491,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
# set trainable params
params.append({
'params': self.adapter.parameters(),
'params': list(self.adapter.parameters()),
'lr': self.train_config.adapter_lr
})

View File

@@ -1161,13 +1161,13 @@ class IPAdapter(torch.nn.Module):
# when training just scaler, we do not train anything else
if not self.config.train_scaler:
param_groups.append({
"params": self.get_non_scaler_parameters(),
"params": list(self.get_non_scaler_parameters()),
"lr": adapter_lr,
})
if self.config.train_scaler or self.config.merge_scaler:
scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
param_groups.append({
"params": self.get_scaler_parameters(),
"params": list(self.get_scaler_parameters()),
"lr": scaler_lr,
})
return param_groups

View File

@@ -711,20 +711,30 @@ class VisionDirectAdapter(torch.nn.Module):
self.block_scaler.requires_grad = True
else:
self.block_scaler = None
self.pool = None
if self.config.num_tokens is not None:
image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
# image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
# max_seq_len = CLIP tokens + CLS token
max_seq_len = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# clip
max_seq_len = int(
image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
self.resampler = MLPR(
in_dim=self.token_size,
in_channels=max_seq_len,
out_dim=self.mid_size,
out_channels=self.config.num_tokens,
# max_seq_len = 257
# if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# # clip
# max_seq_len = int(
# image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
# self.resampler = MLPR(
# in_dim=self.token_size,
# in_channels=max_seq_len,
# out_dim=self.mid_size,
# out_channels=self.config.num_tokens,
# )
vision_config = self.adapter_ref().vision_encoder.config
# sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
# siglip doesnt add 1
sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
self.pool = nn.Sequential(
nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False),
Norm(),
)
elif self.config.image_encoder_arch == "pixtral":
@@ -733,7 +743,6 @@ class VisionDirectAdapter(torch.nn.Module):
out_dim=self.mid_size,
)
self.pool = None
self.sparse_autoencoder = None
if self.config.conv_pooling:
vision_config = self.adapter_ref().vision_encoder.config