mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-11 13:39:50 +00:00
DFE tweaks. Adding support for more llms as text encoders
This commit is contained in:
@@ -479,15 +479,13 @@ class ModelConfig:
|
||||
self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
|
||||
self.quantize_kwargs = kwargs.get("quantize_kwargs", {})
|
||||
|
||||
if self.ignore_if_contains is not None or self.only_if_contains is not None:
|
||||
if not self.is_flux:
|
||||
raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
|
||||
|
||||
# splits the model over the available gpus WIP
|
||||
self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
|
||||
if self.split_model_over_gpus and not self.is_flux:
|
||||
raise ValueError("split_model_over_gpus is only supported with flux models currently")
|
||||
self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
|
||||
|
||||
self.te_name_or_path = kwargs.get("te_name_or_path", None)
|
||||
|
||||
|
||||
class EMAConfig:
|
||||
|
||||
@@ -240,7 +240,8 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
timesteps,
|
||||
batch: DataLoaderBatchDTO,
|
||||
scheduler: CustomFlowMatchEulerDiscreteScheduler,
|
||||
lpips_weight=1.0,
|
||||
# lpips_weight=1.0,
|
||||
lpips_weight=10.0,
|
||||
clip_weight=0.1,
|
||||
pixel_weight=0.1
|
||||
):
|
||||
@@ -286,30 +287,35 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
# go from -1 to 1 to 0 to 1
|
||||
target_img = (target_img + 1) / 2
|
||||
lpips_feat_list_target = self.get_lpips_features(target_img.float())
|
||||
target_clip_output = self.get_siglip_features(target_img).detach()
|
||||
if clip_weight > 0:
|
||||
target_clip_output = self.get_siglip_features(target_img).detach()
|
||||
if clip_weight > 0:
|
||||
pred_clip_output = self.get_siglip_features(pred_images)
|
||||
clip_loss = torch.nn.functional.mse_loss(
|
||||
pred_clip_output.float(), target_clip_output.float()
|
||||
) * clip_weight
|
||||
|
||||
if 'clip_loss' not in self.losses:
|
||||
self.losses['clip_loss'] = clip_loss.item()
|
||||
else:
|
||||
self.losses['clip_loss'] += clip_loss.item()
|
||||
|
||||
total_loss += clip_loss
|
||||
|
||||
pred_clip_output = self.get_siglip_features(pred_images)
|
||||
clip_loss = torch.nn.functional.mse_loss(
|
||||
pred_clip_output.float(), target_clip_output.float()
|
||||
) * clip_weight
|
||||
|
||||
if 'clip_loss' not in self.losses:
|
||||
self.losses['clip_loss'] = clip_loss.item()
|
||||
else:
|
||||
self.losses['clip_loss'] += clip_loss.item()
|
||||
|
||||
total_loss += clip_loss
|
||||
skip_lpips_layers = []
|
||||
|
||||
lpips_loss = 0
|
||||
for idx, lpips_feat in enumerate(lpips_feat_list_pred):
|
||||
if idx in skip_lpips_layers:
|
||||
continue
|
||||
lpips_loss += torch.nn.functional.mse_loss(
|
||||
lpips_feat.float(), lpips_feat_list_target[idx].float()
|
||||
) * lpips_weight
|
||||
|
||||
if 'lpips_loss' not in self.losses:
|
||||
self.losses['lpips_loss'] = lpips_loss.item()
|
||||
else:
|
||||
self.losses['lpips_loss'] += lpips_loss.item()
|
||||
if f'lpips_loss_{idx}' not in self.losses:
|
||||
self.losses[f'lpips_loss_{idx}'] = lpips_loss.item()
|
||||
else:
|
||||
self.losses[f'lpips_loss_{idx}'] += lpips_loss.item()
|
||||
|
||||
total_loss += lpips_loss
|
||||
|
||||
|
||||
@@ -475,6 +475,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
attention_mask: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
|
||||
hidden_size = self.config.get("hidden_size", 2304)
|
||||
# pad or slice text encoder
|
||||
if encoder_hidden_states.shape[2] > hidden_size:
|
||||
encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size]
|
||||
elif encoder_hidden_states.shape[2] < hidden_size:
|
||||
encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2]))
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model
|
||||
from typing import TYPE_CHECKING
|
||||
from toolkit.print import print_acc
|
||||
from diffusers import FluxFillPipeline
|
||||
from transformers import AutoModel, AutoTokenizer, Gemma2Model
|
||||
from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
@@ -801,10 +801,15 @@ class StableDiffusion:
|
||||
print_acc("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
print_acc("Loading Gemma2")
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
|
||||
if self.model_config.te_name_or_path is not None:
|
||||
print_acc("Loading TE")
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
|
||||
text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
|
||||
else:
|
||||
print_acc("Loading Gemma2")
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
@@ -2772,6 +2777,10 @@ class StableDiffusion:
|
||||
te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
|
||||
elif isinstance(self.text_encoder, Gemma2Model):
|
||||
te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
|
||||
elif isinstance(self.text_encoder, Qwen2Model):
|
||||
te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
|
||||
elif isinstance(self.text_encoder, LlamaModel):
|
||||
te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
|
||||
else:
|
||||
te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
|
||||
|
||||
|
||||
Reference in New Issue
Block a user