Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 03:01:28 +00:00
DFE tweaks. Adding support for more LLMs as text encoders
@@ -240,7 +240,8 @@ class DiffusionFeatureExtractor3(nn.Module):
         timesteps,
         batch: DataLoaderBatchDTO,
         scheduler: CustomFlowMatchEulerDiscreteScheduler,
-        lpips_weight=1.0,
+        # lpips_weight=1.0,
+        lpips_weight=10.0,
         clip_weight=0.1,
         pixel_weight=0.1
     ):
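For context on the new defaults, a minimal sketch of how weights like these combine the individual loss terms; the function and argument names are illustrative assumptions, not the toolkit's actual API:

import torch

# Hypothetical illustration (not the actual DFE code): each already-computed
# loss term is scaled by its weight before summation, so the new
# lpips_weight=10.0 default makes the LPIPS term dominate the 0.1-weighted
# CLIP and pixel terms.
def weighted_total_loss(lpips_loss: torch.Tensor,
                        clip_loss: torch.Tensor,
                        pixel_loss: torch.Tensor,
                        lpips_weight: float = 10.0,
                        clip_weight: float = 0.1,
                        pixel_weight: float = 0.1) -> torch.Tensor:
    return (lpips_weight * lpips_loss
            + clip_weight * clip_loss
            + pixel_weight * pixel_loss)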
@@ -286,30 +287,35 @@ class DiffusionFeatureExtractor3(nn.Module):
         # go from -1 to 1 to 0 to 1
         target_img = (target_img + 1) / 2
         lpips_feat_list_target = self.get_lpips_features(target_img.float())
-        target_clip_output = self.get_siglip_features(target_img).detach()
         if clip_weight > 0:
+            target_clip_output = self.get_siglip_features(target_img).detach()
             pred_clip_output = self.get_siglip_features(pred_images)
             clip_loss = torch.nn.functional.mse_loss(
                 pred_clip_output.float(), target_clip_output.float()
             ) * clip_weight

             if 'clip_loss' not in self.losses:
                 self.losses['clip_loss'] = clip_loss.item()
             else:
                 self.losses['clip_loss'] += clip_loss.item()

             total_loss += clip_loss

         skip_lpips_layers = []

         lpips_loss = 0
         for idx, lpips_feat in enumerate(lpips_feat_list_pred):
             if idx in skip_lpips_layers:
                 continue
             lpips_loss += torch.nn.functional.mse_loss(
                 lpips_feat.float(), lpips_feat_list_target[idx].float()
             ) * lpips_weight

             if 'lpips_loss' not in self.losses:
                 self.losses['lpips_loss'] = lpips_loss.item()
             else:
                 self.losses['lpips_loss'] += lpips_loss.item()
+            if f'lpips_loss_{idx}' not in self.losses:
+                self.losses[f'lpips_loss_{idx}'] = lpips_loss.item()
+            else:
+                self.losses[f'lpips_loss_{idx}'] += lpips_loss.item()

             total_loss += lpips_loss
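The CLIP branch now computes SigLIP features only when the term is enabled. A minimal standalone sketch of that guard pattern, with get_features standing in for self.get_siglip_features (the names here are assumptions for illustration):

import torch
import torch.nn.functional as F

def clip_feature_loss(get_features, pred_images, target_img, clip_weight=0.1):
    # Guard pattern from the hunk above: a zero weight skips the feature
    # extractor entirely instead of multiplying its output by 0.
    if clip_weight <= 0:
        return torch.tensor(0.0)
    target_feats = get_features(target_img).detach()  # no gradient into target
    pred_feats = get_features(pred_images)
    return F.mse_loss(pred_feats.float(), target_feats.float()) * clip_weight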
@@ -475,6 +475,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:

+        hidden_size = self.config.get("hidden_size", 2304)
+        # pad or slice text encoder
+        if encoder_hidden_states.shape[2] > hidden_size:
+            encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size]
+        elif encoder_hidden_states.shape[2] < hidden_size:
+            encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2]))
+
         batch_size = hidden_states.size(0)
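This hunk is what enables differently sized LLMs as text encoders: embeddings wider than the transformer's hidden size are sliced, narrower ones zero-padded. A self-contained sketch of the same idea, with match_hidden_size as a hypothetical helper name:

import torch
import torch.nn.functional as F

def match_hidden_size(encoder_hidden_states: torch.Tensor,
                      hidden_size: int = 2304) -> torch.Tensor:
    # Make a text encoder's (batch, seq_len, dim) output fit a transformer
    # that expects `hidden_size` channels: truncate wider embeddings, and
    # zero-pad narrower ones on the right of the last dimension.
    dim = encoder_hidden_states.shape[2]
    if dim > hidden_size:
        return encoder_hidden_states[:, :, :hidden_size]
    if dim < hidden_size:
        return F.pad(encoder_hidden_states, (0, hidden_size - dim))
    return encoder_hidden_states

# e.g. a 4096-dim LLM embedding gets sliced to 2304; a 1536-dim one is padded
emb = torch.randn(1, 77, 4096)
assert match_hidden_size(emb).shape == (1, 77, 2304)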