DFE tweaks. Adding support for more LLMs as text encoders

Jaret Burkett
2025-02-13 04:31:49 -07:00
parent 8450aca10e
commit 2622de1e01
4 changed files with 46 additions and 26 deletions

View File

@@ -479,15 +479,13 @@ class ModelConfig:
         self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
         self.quantize_kwargs = kwargs.get("quantize_kwargs", {})
-        if self.ignore_if_contains is not None or self.only_if_contains is not None:
-            if not self.is_flux:
-                raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
         # splits the model over the available gpus WIP
         self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
         if self.split_model_over_gpus and not self.is_flux:
             raise ValueError("split_model_over_gpus is only supported with flux models currently")
         self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
+        self.te_name_or_path = kwargs.get("te_name_or_path", None)


 class EMAConfig:
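Note: te_name_or_path is the only new config field; None keeps the text encoder bundled with the base model, while a Hugging Face repo id or local path overrides it. A standalone sketch of the intended usage — the class name and repo id below are illustrative, not part of the commit:

from typing import Optional

class ModelConfigSketch:
    """Stand-in for the ModelConfig fields this commit touches."""
    def __init__(self, **kwargs):
        # None means "use the text encoder shipped with the base model";
        # a repo id or local path swaps in a different LLM.
        self.te_name_or_path: Optional[str] = kwargs.get("te_name_or_path", None)

cfg = ModelConfigSketch(te_name_or_path="Qwen/Qwen2.5-7B-Instruct")  # placeholder repo id
assert cfg.te_name_or_path == "Qwen/Qwen2.5-7B-Instruct"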

View File

@@ -240,7 +240,8 @@ class DiffusionFeatureExtractor3(nn.Module):
         timesteps,
         batch: DataLoaderBatchDTO,
         scheduler: CustomFlowMatchEulerDiscreteScheduler,
-        lpips_weight=1.0,
+        # lpips_weight=1.0,
+        lpips_weight=10.0,
         clip_weight=0.1,
         pixel_weight=0.1
     ):
@@ -286,30 +287,35 @@ class DiffusionFeatureExtractor3(nn.Module):
         # go from -1 to 1 to 0 to 1
         target_img = (target_img + 1) / 2
         lpips_feat_list_target = self.get_lpips_features(target_img.float())
-        target_clip_output = self.get_siglip_features(target_img).detach()
+        if clip_weight > 0:
+            target_clip_output = self.get_siglip_features(target_img).detach()
-        if clip_weight > 0:
-            pred_clip_output = self.get_siglip_features(pred_images)
-            clip_loss = torch.nn.functional.mse_loss(
-                pred_clip_output.float(), target_clip_output.float()
-            ) * clip_weight
-            if 'clip_loss' not in self.losses:
-                self.losses['clip_loss'] = clip_loss.item()
-            else:
-                self.losses['clip_loss'] += clip_loss.item()
-            total_loss += clip_loss
+            pred_clip_output = self.get_siglip_features(pred_images)
+            clip_loss = torch.nn.functional.mse_loss(
+                pred_clip_output.float(), target_clip_output.float()
+            ) * clip_weight
+            if 'clip_loss' not in self.losses:
+                self.losses['clip_loss'] = clip_loss.item()
+            else:
+                self.losses['clip_loss'] += clip_loss.item()
+            total_loss += clip_loss

         skip_lpips_layers = []
         lpips_loss = 0
         for idx, lpips_feat in enumerate(lpips_feat_list_pred):
             if idx in skip_lpips_layers:
                 continue
             lpips_loss += torch.nn.functional.mse_loss(
                 lpips_feat.float(), lpips_feat_list_target[idx].float()
             ) * lpips_weight
             if 'lpips_loss' not in self.losses:
                 self.losses['lpips_loss'] = lpips_loss.item()
             else:
                 self.losses['lpips_loss'] += lpips_loss.item()
+            if f'lpips_loss_{idx}' not in self.losses:
+                self.losses[f'lpips_loss_{idx}'] = lpips_loss.item()
+            else:
+                self.losses[f'lpips_loss_{idx}'] += lpips_loss.item()
         total_loss += lpips_loss
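Note: two behavioral changes in this file are easier to see outside the diff: the SigLIP term (including the target features) is now computed only when clip_weight is positive, and LPIPS is additionally logged under a per-layer key. A runnable sketch of the pattern with random stand-in tensors — all names and shapes here are illustrative, not the committed code:

import torch

losses = {}
total_loss = torch.zeros(())

clip_weight = 0.1
if clip_weight > 0:  # feature extraction is skipped entirely when the term is disabled
    pred_feat, target_feat = torch.randn(2, 768), torch.randn(2, 768)
    clip_loss = torch.nn.functional.mse_loss(pred_feat, target_feat) * clip_weight
    losses["clip_loss"] = losses.get("clip_loss", 0.0) + clip_loss.item()
    total_loss = total_loss + clip_loss

lpips_weight = 10.0  # the new default from the hunk above
lpips_loss = torch.zeros(())
layer_pairs = [(torch.randn(4, 64), torch.randn(4, 64)) for _ in range(3)]
for idx, (pred, tgt) in enumerate(layer_pairs):
    lpips_loss = lpips_loss + torch.nn.functional.mse_loss(pred, tgt) * lpips_weight
    # per-layer key introduced by this commit; it records the running sum
    key = f"lpips_loss_{idx}"
    losses[key] = losses.get(key, 0.0) + lpips_loss.item()
total_loss = total_loss + lpips_loss

Because lpips_loss is a running sum, the new lpips_loss_{idx} keys record cumulative values up to each layer rather than isolated per-layer losses.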

View File

@@ -475,6 +475,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        hidden_size = self.config.get("hidden_size", 2304)
+        # pad or slice the text encoder hidden states to match the transformer width
+        if encoder_hidden_states.shape[2] > hidden_size:
+            encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size]
+        elif encoder_hidden_states.shape[2] < hidden_size:
+            encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2]))
         batch_size = hidden_states.size(0)
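Note: this is the piece that makes arbitrary LLM widths workable without a learned projection — oversized hidden states are truncated and undersized ones zero-padded, so the pretrained transformer weights still load unchanged. A standalone sketch of the same trick (the function name and sample width are illustrative):

import torch
import torch.nn.functional as F

def match_hidden_size(encoder_hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    dim = encoder_hidden_states.shape[2]
    if dim > hidden_size:
        # larger encoder: truncate the channel dimension
        return encoder_hidden_states[:, :, :hidden_size]
    if dim < hidden_size:
        # smaller encoder: zero-pad the last dimension on the right
        return F.pad(encoder_hidden_states, (0, hidden_size - dim))
    return encoder_hidden_states

emb = torch.randn(1, 77, 3584)             # stand-in for a wider LLM output
print(match_hidden_size(emb, 2304).shape)  # torch.Size([1, 77, 2304])

The trade-off is that truncation discards features and padding injects zeros; in exchange, no new trainable parameters are introduced.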

View File

@@ -68,7 +68,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
 from diffusers import FluxFillPipeline
-from transformers import AutoModel, AutoTokenizer, Gemma2Model
+from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel

 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
@@ -801,10 +801,15 @@ class StableDiffusion:
print_acc("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
print_acc("Loading Gemma2")
tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
if self.model_config.te_name_or_path is not None:
print_acc("Loading TE")
tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
else:
print_acc("Loading Gemma2")
tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
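Note: AutoModel.from_pretrained on a causal-LM checkpoint returns the bare decoder stack (e.g. Qwen2Model or LlamaModel) rather than the LM-head wrapper, which is what the gradient checks below rely on. A minimal sketch of the override path, with a placeholder repo id and example dtype/device:

import torch
from transformers import AutoModel, AutoTokenizer

te_name_or_path = "Qwen/Qwen2.5-7B-Instruct"  # placeholder; would come from ModelConfig
tokenizer = AutoTokenizer.from_pretrained(te_name_or_path)
text_encoder = AutoModel.from_pretrained(te_name_or_path, torch_dtype=torch.bfloat16)
text_encoder.to("cuda" if torch.cuda.is_available() else "cpu")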
@@ -2772,6 +2777,10 @@ class StableDiffusion:
             te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
         elif isinstance(self.text_encoder, Gemma2Model):
             te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
+        elif isinstance(self.text_encoder, Qwen2Model):
+            te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
+        elif isinstance(self.text_encoder, LlamaModel):
+            te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
         else:
             te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
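Note: the three decoder-only branches probe the identical submodule, so they could be collapsed with a tuple isinstance check. A sketch of that simplification — not the committed code:

import torch.nn as nn
from transformers import Gemma2Model, LlamaModel, Qwen2Model

def decoder_te_has_grad(text_encoder: nn.Module) -> bool:
    # Gemma2, Qwen2 and Llama share the decoder layout probed here,
    # so one branch covers all three identical cases.
    if isinstance(text_encoder, (Gemma2Model, Qwen2Model, LlamaModel)):
        return text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
    # CLIP-style fallback mirroring the else branch above
    return text_encoder.text_model.final_layer_norm.weight.requires_grad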