diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 76e49b04..d9a0fb8c 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1473,7 +1473,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # self.step_num = self.embedding.step
             # self.start_step = self.step_num
             params.append({
-                'params': self.embedding.get_trainable_params(),
+                'params': list(self.embedding.get_trainable_params()),
                 'lr': self.train_config.embedding_lr
             })
 
@@ -1491,7 +1491,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             # set trainable params
             params.append({
-                'params': self.adapter.parameters(),
+                'params': list(self.adapter.parameters()),
                 'lr': self.train_config.adapter_lr
             })
 
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index beacb203..4821e968 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -1161,13 +1161,13 @@ class IPAdapter(torch.nn.Module):
         # when training just scaler, we do not train anything else
         if not self.config.train_scaler:
             param_groups.append({
-                "params": self.get_non_scaler_parameters(),
+                "params": list(self.get_non_scaler_parameters()),
                 "lr": adapter_lr,
             })
         if self.config.train_scaler or self.config.merge_scaler:
             scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
             param_groups.append({
-                "params": self.get_scaler_parameters(),
+                "params": list(self.get_scaler_parameters()),
                 "lr": scaler_lr,
             })
         return param_groups
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index 946a1e10..ea3f9bc7 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -711,20 +711,30 @@ class VisionDirectAdapter(torch.nn.Module):
             self.block_scaler.requires_grad = True
         else:
             self.block_scaler = None
+
+        self.pool = None
 
         if self.config.num_tokens is not None:
-            image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
+            # image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
             # max_seq_len = CLIP tokens + CLS token
-            max_seq_len = 257
-            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
-                # clip
-                max_seq_len = int(
-                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
-            self.resampler = MLPR(
-                in_dim=self.token_size,
-                in_channels=max_seq_len,
-                out_dim=self.mid_size,
-                out_channels=self.config.num_tokens,
+            # max_seq_len = 257
+            # if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+            #     # clip
+            #     max_seq_len = int(
+            #         image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+            # self.resampler = MLPR(
+            #     in_dim=self.token_size,
+            #     in_channels=max_seq_len,
+            #     out_dim=self.mid_size,
+            #     out_channels=self.config.num_tokens,
+            # )
+            vision_config = self.adapter_ref().vision_encoder.config
+            # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
+            # siglip doesnt add 1
+            sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
+            self.pool = nn.Sequential(
+                nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False),
+                Norm(),
             )
 
         elif self.config.image_encoder_arch == "pixtral":
@@ -733,7 +743,6 @@ class VisionDirectAdapter(torch.nn.Module):
                 out_dim=self.mid_size,
             )
 
-        self.pool = None
         self.sparse_autoencoder = None
         if self.config.conv_pooling:
             vision_config = self.adapter_ref().vision_encoder.config
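
The two sketches below are not part of the patch; they only illustrate the behavior the diff relies on, using standard PyTorch and placeholder modules.

The list(...) wrappers matter because Module.parameters() and helpers like it return generators, which are exhausted after a single pass. If anything iterates a param group before (or besides) the optimizer, for example to count or log trainable parameters, a generator-backed group comes up empty on the next pass. A minimal sketch, with nn.Linear standing in for the project's adapter:

    import torch
    import torch.nn as nn

    adapter = nn.Linear(4, 4)  # stand-in for the real adapter module

    # Generator-backed group: the first iteration drains it.
    gen_group = {"params": adapter.parameters(), "lr": 1e-4}
    _ = sum(p.numel() for p in gen_group["params"])   # e.g. a param-count/log pass
    print(len(list(gen_group["params"])))             # 0 -- nothing left to optimize

    # List-backed group survives repeated iteration.
    list_group = {"params": list(adapter.parameters()), "lr": 1e-4}
    _ = sum(p.numel() for p in list_group["params"])
    optimizer = torch.optim.AdamW([list_group])       # still sees weight and bias
    print(len(list_group["params"]))                  # 2

In vd_adapter.py the MLPR resampler is replaced with a Conv1d over the token axis: with kernel size 1, Conv1d(sequence_length, num_tokens, 1) learns a linear mix of the encoder's sequence positions while leaving the feature dimension untouched (Norm() is the repo's own module and is omitted here). A shape check with illustrative numbers (384px input, 16px patches, no CLS token per the SigLIP comment; these dimensions are assumptions, not the project's configuration):

    import torch
    import torch.nn as nn

    seq_len = int((384 / 16) ** 2)            # 576 patch tokens, no CLS token
    num_tokens, dim = 16, 1024                # illustrative values
    pool = nn.Conv1d(seq_len, num_tokens, 1, bias=False)

    feats = torch.randn(2, seq_len, dim)      # [batch, seq_len, dim] vision features
    print(pool(feats).shape)                  # torch.Size([2, 16, 1024])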