DFE tweaks. Adding support for more llms as text encoders

This commit is contained in:
Jaret Burkett
2025-02-13 04:31:49 -07:00
parent 8450aca10e
commit 2622de1e01
4 changed files with 46 additions and 26 deletions

View File

@@ -240,7 +240,8 @@ class DiffusionFeatureExtractor3(nn.Module):
timesteps,
batch: DataLoaderBatchDTO,
scheduler: CustomFlowMatchEulerDiscreteScheduler,
lpips_weight=1.0,
# lpips_weight=1.0,
lpips_weight=10.0,
clip_weight=0.1,
pixel_weight=0.1
):
@@ -286,30 +287,35 @@ class DiffusionFeatureExtractor3(nn.Module):
# go from -1 to 1 to 0 to 1
target_img = (target_img + 1) / 2
lpips_feat_list_target = self.get_lpips_features(target_img.float())
target_clip_output = self.get_siglip_features(target_img).detach()
if clip_weight > 0:
target_clip_output = self.get_siglip_features(target_img).detach()
if clip_weight > 0:
pred_clip_output = self.get_siglip_features(pred_images)
clip_loss = torch.nn.functional.mse_loss(
pred_clip_output.float(), target_clip_output.float()
) * clip_weight
if 'clip_loss' not in self.losses:
self.losses['clip_loss'] = clip_loss.item()
else:
self.losses['clip_loss'] += clip_loss.item()
total_loss += clip_loss
pred_clip_output = self.get_siglip_features(pred_images)
clip_loss = torch.nn.functional.mse_loss(
pred_clip_output.float(), target_clip_output.float()
) * clip_weight
if 'clip_loss' not in self.losses:
self.losses['clip_loss'] = clip_loss.item()
else:
self.losses['clip_loss'] += clip_loss.item()
total_loss += clip_loss
skip_lpips_layers = []
lpips_loss = 0
for idx, lpips_feat in enumerate(lpips_feat_list_pred):
if idx in skip_lpips_layers:
continue
lpips_loss += torch.nn.functional.mse_loss(
lpips_feat.float(), lpips_feat_list_target[idx].float()
) * lpips_weight
if 'lpips_loss' not in self.losses:
self.losses['lpips_loss'] = lpips_loss.item()
else:
self.losses['lpips_loss'] += lpips_loss.item()
if f'lpips_loss_{idx}' not in self.losses:
self.losses[f'lpips_loss_{idx}'] = lpips_loss.item()
else:
self.losses[f'lpips_loss_{idx}'] += lpips_loss.item()
total_loss += lpips_loss

View File

@@ -475,6 +475,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
attention_mask: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
hidden_size = self.config.get("hidden_size", 2304)
# pad or slice text encoder
if encoder_hidden_states.shape[2] > hidden_size:
encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size]
elif encoder_hidden_states.shape[2] < hidden_size:
encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2]))
batch_size = hidden_states.size(0)