diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 8619f559..284f4fdb 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -479,15 +479,13 @@ class ModelConfig:
         self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
         self.quantize_kwargs = kwargs.get("quantize_kwargs", {})
 
-        if self.ignore_if_contains is not None or self.only_if_contains is not None:
-            if not self.is_flux:
-                raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
-
         # splits the model over the available gpus WIP
         self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
         if self.split_model_over_gpus and not self.is_flux:
             raise ValueError("split_model_over_gpus is only supported with flux models currently")
         self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
+
+        self.te_name_or_path = kwargs.get("te_name_or_path", None)
 
 
 class EMAConfig:
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 88ffbd0b..29becb69 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -240,7 +240,8 @@ class DiffusionFeatureExtractor3(nn.Module):
         timesteps,
         batch: DataLoaderBatchDTO,
         scheduler: CustomFlowMatchEulerDiscreteScheduler,
-        lpips_weight=1.0,
+        # lpips_weight=1.0,
+        lpips_weight=10.0,
         clip_weight=0.1,
         pixel_weight=0.1
     ):
@@ -286,30 +287,35 @@ class DiffusionFeatureExtractor3(nn.Module):
         # go from -1 to 1 to 0 to 1
         target_img = (target_img + 1) / 2
         lpips_feat_list_target = self.get_lpips_features(target_img.float())
-        target_clip_output = self.get_siglip_features(target_img).detach()
+        if clip_weight > 0:
+            target_clip_output = self.get_siglip_features(target_img).detach()
 
+        if clip_weight > 0:
+            pred_clip_output = self.get_siglip_features(pred_images)
+            clip_loss = torch.nn.functional.mse_loss(
+                pred_clip_output.float(), target_clip_output.float()
+            ) * clip_weight
+
+            if 'clip_loss' not in self.losses:
+                self.losses['clip_loss'] = clip_loss.item()
+            else:
+                self.losses['clip_loss'] += clip_loss.item()
+
+            total_loss += clip_loss
 
-        pred_clip_output = self.get_siglip_features(pred_images)
-        clip_loss = torch.nn.functional.mse_loss(
-            pred_clip_output.float(), target_clip_output.float()
-        ) * clip_weight
-
-        if 'clip_loss' not in self.losses:
-            self.losses['clip_loss'] = clip_loss.item()
-        else:
-            self.losses['clip_loss'] += clip_loss.item()
-
-        total_loss += clip_loss
+        skip_lpips_layers = []
 
         lpips_loss = 0
         for idx, lpips_feat in enumerate(lpips_feat_list_pred):
+            if idx in skip_lpips_layers:
+                continue
             lpips_loss += torch.nn.functional.mse_loss(
                 lpips_feat.float(), lpips_feat_list_target[idx].float()
             ) * lpips_weight
 
-            if 'lpips_loss' not in self.losses:
-                self.losses['lpips_loss'] = lpips_loss.item()
-            else:
-                self.losses['lpips_loss'] += lpips_loss.item()
+            if f'lpips_loss_{idx}' not in self.losses:
+                self.losses[f'lpips_loss_{idx}'] = lpips_loss.item()
+            else:
+                self.losses[f'lpips_loss_{idx}'] += lpips_loss.item()
 
         total_loss += lpips_loss
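The hunk above raises the default `lpips_weight` from 1.0 to 10.0, gates the SigLIP loss behind `clip_weight > 0`, and splits LPIPS logging into per-layer keys with an (empty for now) `skip_lpips_layers` escape hatch. Note that `lpips_loss` keeps accumulating across iterations, so `lpips_loss_{idx}` records the running total through layer `idx`, not that layer's own contribution. A minimal standalone sketch of the logging pattern; the `feats_*` lists are hypothetical stand-ins for `lpips_feat_list_pred` / `lpips_feat_list_target`:

```python
import torch

# Hypothetical stand-ins for the LPIPS feature lists used in the diff.
feats_pred = [torch.randn(2, 8), torch.randn(2, 16), torch.randn(2, 32)]
feats_target = [f + 0.1 for f in feats_pred]

losses = {}             # mirrors self.losses on DiffusionFeatureExtractor3
lpips_weight = 10.0     # new default from the hunk above
skip_lpips_layers = []  # layer indices to exclude; empty, as in the diff

lpips_loss = 0
for idx, feat in enumerate(feats_pred):
    if idx in skip_lpips_layers:
        continue
    lpips_loss += torch.nn.functional.mse_loss(
        feat.float(), feats_target[idx].float()
    ) * lpips_weight
    # lpips_loss is a running sum at this point, so each per-layer key
    # logs the cumulative loss through layer idx (matching the diff).
    key = f'lpips_loss_{idx}'
    losses[key] = losses.get(key, 0.0) + lpips_loss.item()

total_loss = lpips_loss
print(losses)
```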
diff --git a/toolkit/models/lumina2.py b/toolkit/models/lumina2.py
index f26e90ca..628f28b8 100644
--- a/toolkit/models/lumina2.py
+++ b/toolkit/models/lumina2.py
@@ -475,6 +475,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+
+        hidden_size = self.config.get("hidden_size", 2304)
+        # pad or slice text encoder
+        if encoder_hidden_states.shape[2] > hidden_size:
+            encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size]
+        elif encoder_hidden_states.shape[2] < hidden_size:
+            encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2]))
 
         batch_size = hidden_states.size(0)
 
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 132890a1..8b2a6506 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -68,7 +68,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
 from diffusers import FluxFillPipeline
-from transformers import AutoModel, AutoTokenizer, Gemma2Model
+from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
@@ -801,10 +801,15 @@ class StableDiffusion:
             print_acc("Loading vae")
             vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
             flush()
-
-            print_acc("Loading Gemma2")
-            tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
-            text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
+
+            if self.model_config.te_name_or_path is not None:
+                print_acc("Loading TE")
+                tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
+                text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
+            else:
+                print_acc("Loading Gemma2")
+                tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
+                text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
             text_encoder.to(self.device_torch, dtype=dtype)
             flush()
 
@@ -2772,6 +2777,10 @@ class StableDiffusion:
                 te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
             elif isinstance(self.text_encoder, Gemma2Model):
                 te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
+            elif isinstance(self.text_encoder, Qwen2Model):
+                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
+            elif isinstance(self.text_encoder, LlamaModel):
+                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
             else:
                 te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
 
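Taken together, the ModelConfig and loader changes let a Lumina2 run swap the stock Gemma2 text encoder for any `AutoModel`-compatible checkpoint via `te_name_or_path` (presumably set in the `model:` section of a training config), with the new Qwen2Model/LlamaModel branches reusing Gemma2's `gate_proj` layout for the gradient check. Since a substitute encoder rarely matches the transformer's expected embedding width, the Lumina2 forward pass now pads or slices the last dimension. A minimal sketch of that width matching, with made-up tensor shapes standing in for real encoder outputs:

```python
import torch
import torch.nn.functional as F

def match_hidden_size(encoder_hidden_states: torch.Tensor, hidden_size: int = 2304) -> torch.Tensor:
    """Pad or slice the channel dim, mirroring the new Lumina2 forward() logic."""
    dim = encoder_hidden_states.shape[2]
    if dim > hidden_size:
        return encoder_hidden_states[:, :, :hidden_size]
    if dim < hidden_size:
        return F.pad(encoder_hidden_states, (0, hidden_size - dim))
    return encoder_hidden_states

# A wider TE output (e.g. a 3584-dim Qwen2 variant) gets sliced...
print(match_hidden_size(torch.randn(1, 77, 3584)).shape)  # torch.Size([1, 77, 2304])
# ...and a narrower one (e.g. a 1536-dim model) gets zero-padded.
print(match_hidden_size(torch.randn(1, 77, 1536)).shape)  # torch.Size([1, 77, 2304])
```

Slicing discards the encoder's trailing channels and zero-padding feeds the transformer uninformative dimensions, so either path presumably relies on fine-tuning to adapt; the diff's approach only guarantees shape compatibility.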