Small fixed for DFE, polar guidance, and other things

This commit is contained in:
Jaret Burkett
2025-02-12 09:27:44 -07:00
parent 10aa7e9d5e
commit 787bb37e76
7 changed files with 87 additions and 43 deletions

View File

@@ -404,13 +404,14 @@ class SDTrainer(BaseSDTrainProcess):
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
elif self.dfe.version == 3:
dfe_loss = self.dfe(
noise=noise,
noise_pred=noise_pred,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
scheduler=self.sd.noise_scheduler
)
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
else:
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
@@ -563,6 +564,7 @@ class SDTrainer(BaseSDTrainProcess):
noise=noise,
sd=self.sd,
unconditional_embeds=unconditional_embeds,
train_config=self.train_config,
**kwargs
)

View File

@@ -1387,12 +1387,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.load_training_state_from_metadata(latest_save_path)
# get the noise scheduler
arch = 'sd'
if self.model_config.is_pixart:
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
'sd' if not self.model_config.is_pixart else 'pixart'
arch
)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:

View File

@@ -403,7 +403,7 @@ class TrainConfig:
# diffusion feature extractor
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 0.1)
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
# optimal noise pairing
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)

View File

@@ -6,6 +6,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
from toolkit.config_modules import TrainConfig
GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]
@@ -407,6 +408,7 @@ def get_guided_loss_polarity(
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
train_config: 'TrainConfig',
scaler=None,
**kwargs
):
@@ -423,8 +425,22 @@ def get_guided_loss_polarity(
target_neg = noise
if sd.is_flow_matching:
# set the timesteps for flow matching as linear since we will do weighing
sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
linear_timesteps = any([
train_config.linear_timesteps,
train_config.linear_timesteps2,
train_config.timestep_type == 'linear',
])
timestep_type = 'linear' if linear_timesteps else None
if timestep_type is None:
timestep_type = train_config.timestep_type
sd.noise_scheduler.set_train_timesteps(
1000,
device=device,
timestep_type=timestep_type,
latents=conditional_latents
)
target_pos = (noise - conditional_latents).detach()
target_neg = (noise - unconditional_latents).detach()
@@ -481,11 +497,6 @@ def get_guided_loss_polarity(
loss = pred_loss + pred_neg_loss
# if sd.is_flow_matching:
# timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
# loss = loss * timestep_weight
loss = loss.mean([1, 2, 3])
loss = loss.mean()
if scaler is not None:
@@ -609,6 +620,7 @@ def get_guidance_loss(
mask_multiplier=None,
prior_pred=None,
scaler=None,
train_config=None,
**kwargs
):
# TODO add others and process individual batch items separately
@@ -641,6 +653,7 @@ def get_guidance_loss(
noise,
sd,
scaler=scaler,
train_config=train_config,
**kwargs
)
elif guidance_type == "tnt":

View File

@@ -226,45 +226,48 @@ class DiffusionFeatureExtractor3(nn.Module):
return feats_list
# do lpips
lpips_feat_list = [x.detach() for x in get_lpips_features(
lpips_feat_list = [x for x in get_lpips_features(
tensors_n1p1.to(device, dtype=torch.float32))]
return lpips_feat_list
def forward(
self,
self,
noise,
noise_pred,
noisy_latents,
timesteps,
batch: DataLoaderBatchDTO,
scheduler: CustomFlowMatchEulerDiscreteScheduler,
lpips_weight=20.0,
lpips_weight=1.0,
clip_weight=0.1,
pixel_weight=1.0
pixel_weight=0.1
):
dtype = torch.bfloat16
device = self.vae.device
# first we step the scheduler from current timestep to the very end for a full denoise
bs = noise_pred.shape[0]
noise_pred_chunks = torch.chunk(noise_pred, bs)
timestep_chunks = torch.chunk(timesteps, bs)
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
stepped_chunks = []
for idx in range(bs):
model_output = noise_pred_chunks[idx]
timestep = timestep_chunks[idx]
scheduler._step_index = None
scheduler._init_step_index(timestep)
sample = noisy_latent_chunks[idx].to(torch.float32)
# bs = noise_pred.shape[0]
# noise_pred_chunks = torch.chunk(noise_pred, bs)
# timestep_chunks = torch.chunk(timesteps, bs)
# noisy_latent_chunks = torch.chunk(noisy_latents, bs)
# stepped_chunks = []
# for idx in range(bs):
# model_output = noise_pred_chunks[idx]
# timestep = timestep_chunks[idx]
# scheduler._step_index = None
# scheduler._init_step_index(timestep)
# sample = noisy_latent_chunks[idx].to(torch.float32)
sigma = scheduler.sigmas[scheduler.step_index]
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
prev_sample = sample + (sigma_next - sigma) * model_output
stepped_chunks.append(prev_sample)
# sigma = scheduler.sigmas[scheduler.step_index]
# sigma_next = scheduler.sigmas[-1] # use last sigma for final step
# prev_sample = sample + (sigma_next - sigma) * model_output
# stepped_chunks.append(prev_sample)
stepped_latents = torch.cat(stepped_chunks, dim=0)
# stepped_latents = torch.cat(stepped_chunks, dim=0)
stepped_latents = noise - noise_pred
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
@@ -274,16 +277,18 @@ class DiffusionFeatureExtractor3(nn.Module):
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
pred_clip_output = self.get_siglip_features(pred_images)
lpips_feat_list_pred = self.get_lpips_features(pred_images.float())
total_loss = 0
with torch.no_grad():
target_img = batch.tensor.to(device, dtype=dtype)
# go from -1 to 1 to 0 to 1
target_img = (target_img + 1) / 2
target_clip_output = self.get_siglip_features(target_img).detach()
lpips_feat_list_target = self.get_lpips_features(target_img.float())
target_clip_output = self.get_siglip_features(target_img).detach()
pred_clip_output = self.get_siglip_features(pred_images)
clip_loss = torch.nn.functional.mse_loss(
pred_clip_output.float(), target_clip_output.float()
) * clip_weight
@@ -293,7 +298,7 @@ class DiffusionFeatureExtractor3(nn.Module):
else:
self.losses['clip_loss'] += clip_loss.item()
total_loss = clip_loss
total_loss += clip_loss
lpips_loss = 0
for idx, lpips_feat in enumerate(lpips_feat_list_pred):
@@ -308,14 +313,14 @@ class DiffusionFeatureExtractor3(nn.Module):
total_loss += lpips_loss
mse_loss = torch.nn.functional.mse_loss(
stepped_latents.float(), batch.latents.float()
) * pixel_weight
# mse_loss = torch.nn.functional.mse_loss(
# stepped_latents.float(), batch.latents.float()
# ) * pixel_weight
if 'pixel_loss' not in self.losses:
self.losses['pixel_loss'] = mse_loss.item()
else:
self.losses['pixel_loss'] += mse_loss.item()
# if 'pixel_loss' not in self.losses:
# self.losses['pixel_loss'] = mse_loss.item()
# else:
# self.losses['pixel_loss'] += mse_loss.item()
if self.step % self.log_every == 0 and self.step > 0:
print(f"DFE losses:")
@@ -325,7 +330,7 @@ class DiffusionFeatureExtractor3(nn.Module):
print(f" - {key}: {self.losses[key]:.3e}")
self.losses[key] = 0.0
total_loss += mse_loss
# total_loss += mse_loss
self.step += 1
return total_loss

View File

@@ -88,6 +88,18 @@ flux_config = {
"use_dynamic_shifting": True
}
sd_flow_config = {
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.30.0.dev0",
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": False
}
def get_sampler(
sampler: str,
@@ -133,6 +145,8 @@ def get_sampler(
elif sampler == "flowmatch":
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
config_to_use = copy.deepcopy(flux_config)
if arch == "sd":
config_to_use = copy.deepcopy(sd_flow_config)
else:
raise ValueError(f"Sampler {sampler} not supported")

View File

@@ -974,12 +974,17 @@ class StableDiffusion:
"prediction_type": self.prediction_type,
})
else:
arch = 'sd'
if self.model_config.is_pixart:
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
noise_scheduler = get_sampler(
sampler,
{
"prediction_type": self.prediction_type,
},
'sd' if not self.is_pixart else 'pixart'
arch
)
try: