mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Small fixes for DFE, polar guidance, and other things
This commit is contained in:
@@ -404,13 +404,14 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||||
elif self.dfe.version == 3:
|
elif self.dfe.version == 3:
|
||||||
dfe_loss = self.dfe(
|
dfe_loss = self.dfe(
|
||||||
|
noise=noise,
|
||||||
noise_pred=noise_pred,
|
noise_pred=noise_pred,
|
||||||
noisy_latents=noisy_latents,
|
noisy_latents=noisy_latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
scheduler=self.sd.noise_scheduler
|
scheduler=self.sd.noise_scheduler
|
||||||
)
|
)
|
||||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
||||||
|
|
||||||
@@ -563,6 +564,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
noise=noise,
|
noise=noise,
|
||||||
sd=self.sd,
|
sd=self.sd,
|
||||||
unconditional_embeds=unconditional_embeds,
|
unconditional_embeds=unconditional_embeds,
|
||||||
|
train_config=self.train_config,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1387,12 +1387,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.load_training_state_from_metadata(latest_save_path)
|
self.load_training_state_from_metadata(latest_save_path)
|
||||||
|
|
||||||
# get the noise scheduler
|
# get the noise scheduler
|
||||||
|
arch = 'sd'
|
||||||
|
if self.model_config.is_pixart:
|
||||||
|
arch = 'pixart'
|
||||||
|
if self.model_config.is_flux:
|
||||||
|
arch = 'flux'
|
||||||
sampler = get_sampler(
|
sampler = get_sampler(
|
||||||
self.train_config.noise_scheduler,
|
self.train_config.noise_scheduler,
|
||||||
{
|
{
|
||||||
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
|
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
|
||||||
},
|
},
|
||||||
'sd' if not self.model_config.is_pixart else 'pixart'
|
arch
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
|
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
|
||||||
|
|||||||
@@ -403,7 +403,7 @@ class TrainConfig:
|
|||||||
|
|
||||||
# diffusion feature extractor
|
# diffusion feature extractor
|
||||||
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
|
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
|
||||||
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 0.1)
|
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
|
||||||
|
|
||||||
# optimal noise pairing
|
# optimal noise pairing
|
||||||
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
|||||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
from toolkit.config_modules import TrainConfig
|
||||||
|
|
||||||
GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]
|
GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]
|
||||||
|
|
||||||
@@ -407,6 +408,7 @@ def get_guided_loss_polarity(
|
|||||||
batch: 'DataLoaderBatchDTO',
|
batch: 'DataLoaderBatchDTO',
|
||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
sd: 'StableDiffusion',
|
sd: 'StableDiffusion',
|
||||||
|
train_config: 'TrainConfig',
|
||||||
scaler=None,
|
scaler=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -423,8 +425,22 @@ def get_guided_loss_polarity(
|
|||||||
target_neg = noise
|
target_neg = noise
|
||||||
|
|
||||||
if sd.is_flow_matching:
|
if sd.is_flow_matching:
|
||||||
# set the timesteps for flow matching as linear since we will do weighing
|
linear_timesteps = any([
|
||||||
sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
|
train_config.linear_timesteps,
|
||||||
|
train_config.linear_timesteps2,
|
||||||
|
train_config.timestep_type == 'linear',
|
||||||
|
])
|
||||||
|
|
||||||
|
timestep_type = 'linear' if linear_timesteps else None
|
||||||
|
if timestep_type is None:
|
||||||
|
timestep_type = train_config.timestep_type
|
||||||
|
|
||||||
|
sd.noise_scheduler.set_train_timesteps(
|
||||||
|
1000,
|
||||||
|
device=device,
|
||||||
|
timestep_type=timestep_type,
|
||||||
|
latents=conditional_latents
|
||||||
|
)
|
||||||
target_pos = (noise - conditional_latents).detach()
|
target_pos = (noise - conditional_latents).detach()
|
||||||
target_neg = (noise - unconditional_latents).detach()
|
target_neg = (noise - unconditional_latents).detach()
|
||||||
|
|
||||||
@@ -481,11 +497,6 @@ def get_guided_loss_polarity(
|
|||||||
|
|
||||||
loss = pred_loss + pred_neg_loss
|
loss = pred_loss + pred_neg_loss
|
||||||
|
|
||||||
# if sd.is_flow_matching:
|
|
||||||
# timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
|
|
||||||
# loss = loss * timestep_weight
|
|
||||||
|
|
||||||
|
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
if scaler is not None:
|
if scaler is not None:
|
||||||
@@ -609,6 +620,7 @@ def get_guidance_loss(
|
|||||||
mask_multiplier=None,
|
mask_multiplier=None,
|
||||||
prior_pred=None,
|
prior_pred=None,
|
||||||
scaler=None,
|
scaler=None,
|
||||||
|
train_config=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# TODO add others and process individual batch items separately
|
# TODO add others and process individual batch items separately
|
||||||
@@ -641,6 +653,7 @@ def get_guidance_loss(
|
|||||||
noise,
|
noise,
|
||||||
sd,
|
sd,
|
||||||
scaler=scaler,
|
scaler=scaler,
|
||||||
|
train_config=train_config,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
elif guidance_type == "tnt":
|
elif guidance_type == "tnt":
|
||||||
|
|||||||
@@ -226,45 +226,48 @@ class DiffusionFeatureExtractor3(nn.Module):
|
|||||||
return feats_list
|
return feats_list
|
||||||
|
|
||||||
# do lpips
|
# do lpips
|
||||||
lpips_feat_list = [x.detach() for x in get_lpips_features(
|
lpips_feat_list = [x for x in get_lpips_features(
|
||||||
tensors_n1p1.to(device, dtype=torch.float32))]
|
tensors_n1p1.to(device, dtype=torch.float32))]
|
||||||
|
|
||||||
return lpips_feat_list
|
return lpips_feat_list
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
noise,
|
||||||
noise_pred,
|
noise_pred,
|
||||||
noisy_latents,
|
noisy_latents,
|
||||||
timesteps,
|
timesteps,
|
||||||
batch: DataLoaderBatchDTO,
|
batch: DataLoaderBatchDTO,
|
||||||
scheduler: CustomFlowMatchEulerDiscreteScheduler,
|
scheduler: CustomFlowMatchEulerDiscreteScheduler,
|
||||||
lpips_weight=20.0,
|
lpips_weight=1.0,
|
||||||
clip_weight=0.1,
|
clip_weight=0.1,
|
||||||
pixel_weight=1.0
|
pixel_weight=0.1
|
||||||
):
|
):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
device = self.vae.device
|
device = self.vae.device
|
||||||
|
|
||||||
# first we step the scheduler from current timestep to the very end for a full denoise
|
# first we step the scheduler from current timestep to the very end for a full denoise
|
||||||
bs = noise_pred.shape[0]
|
# bs = noise_pred.shape[0]
|
||||||
noise_pred_chunks = torch.chunk(noise_pred, bs)
|
# noise_pred_chunks = torch.chunk(noise_pred, bs)
|
||||||
timestep_chunks = torch.chunk(timesteps, bs)
|
# timestep_chunks = torch.chunk(timesteps, bs)
|
||||||
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
# noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||||
stepped_chunks = []
|
# stepped_chunks = []
|
||||||
for idx in range(bs):
|
# for idx in range(bs):
|
||||||
model_output = noise_pred_chunks[idx]
|
# model_output = noise_pred_chunks[idx]
|
||||||
timestep = timestep_chunks[idx]
|
# timestep = timestep_chunks[idx]
|
||||||
scheduler._step_index = None
|
# scheduler._step_index = None
|
||||||
scheduler._init_step_index(timestep)
|
# scheduler._init_step_index(timestep)
|
||||||
sample = noisy_latent_chunks[idx].to(torch.float32)
|
# sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||||
|
|
||||||
sigma = scheduler.sigmas[scheduler.step_index]
|
# sigma = scheduler.sigmas[scheduler.step_index]
|
||||||
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
# sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
||||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
# prev_sample = sample + (sigma_next - sigma) * model_output
|
||||||
stepped_chunks.append(prev_sample)
|
# stepped_chunks.append(prev_sample)
|
||||||
|
|
||||||
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
# stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||||
|
|
||||||
|
stepped_latents = noise - noise_pred
|
||||||
|
|
||||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||||
|
|
||||||
@@ -274,16 +277,18 @@ class DiffusionFeatureExtractor3(nn.Module):
|
|||||||
|
|
||||||
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
|
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
|
||||||
|
|
||||||
pred_clip_output = self.get_siglip_features(pred_images)
|
|
||||||
lpips_feat_list_pred = self.get_lpips_features(pred_images.float())
|
lpips_feat_list_pred = self.get_lpips_features(pred_images.float())
|
||||||
|
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
target_img = batch.tensor.to(device, dtype=dtype)
|
target_img = batch.tensor.to(device, dtype=dtype)
|
||||||
# go from -1 to 1 to 0 to 1
|
# go from -1 to 1 to 0 to 1
|
||||||
target_img = (target_img + 1) / 2
|
target_img = (target_img + 1) / 2
|
||||||
target_clip_output = self.get_siglip_features(target_img).detach()
|
|
||||||
lpips_feat_list_target = self.get_lpips_features(target_img.float())
|
lpips_feat_list_target = self.get_lpips_features(target_img.float())
|
||||||
|
target_clip_output = self.get_siglip_features(target_img).detach()
|
||||||
|
|
||||||
|
pred_clip_output = self.get_siglip_features(pred_images)
|
||||||
clip_loss = torch.nn.functional.mse_loss(
|
clip_loss = torch.nn.functional.mse_loss(
|
||||||
pred_clip_output.float(), target_clip_output.float()
|
pred_clip_output.float(), target_clip_output.float()
|
||||||
) * clip_weight
|
) * clip_weight
|
||||||
@@ -293,7 +298,7 @@ class DiffusionFeatureExtractor3(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.losses['clip_loss'] += clip_loss.item()
|
self.losses['clip_loss'] += clip_loss.item()
|
||||||
|
|
||||||
total_loss = clip_loss
|
total_loss += clip_loss
|
||||||
|
|
||||||
lpips_loss = 0
|
lpips_loss = 0
|
||||||
for idx, lpips_feat in enumerate(lpips_feat_list_pred):
|
for idx, lpips_feat in enumerate(lpips_feat_list_pred):
|
||||||
@@ -308,14 +313,14 @@ class DiffusionFeatureExtractor3(nn.Module):
|
|||||||
|
|
||||||
total_loss += lpips_loss
|
total_loss += lpips_loss
|
||||||
|
|
||||||
mse_loss = torch.nn.functional.mse_loss(
|
# mse_loss = torch.nn.functional.mse_loss(
|
||||||
stepped_latents.float(), batch.latents.float()
|
# stepped_latents.float(), batch.latents.float()
|
||||||
) * pixel_weight
|
# ) * pixel_weight
|
||||||
|
|
||||||
if 'pixel_loss' not in self.losses:
|
# if 'pixel_loss' not in self.losses:
|
||||||
self.losses['pixel_loss'] = mse_loss.item()
|
# self.losses['pixel_loss'] = mse_loss.item()
|
||||||
else:
|
# else:
|
||||||
self.losses['pixel_loss'] += mse_loss.item()
|
# self.losses['pixel_loss'] += mse_loss.item()
|
||||||
|
|
||||||
if self.step % self.log_every == 0 and self.step > 0:
|
if self.step % self.log_every == 0 and self.step > 0:
|
||||||
print(f"DFE losses:")
|
print(f"DFE losses:")
|
||||||
@@ -325,7 +330,7 @@ class DiffusionFeatureExtractor3(nn.Module):
|
|||||||
print(f" - {key}: {self.losses[key]:.3e}")
|
print(f" - {key}: {self.losses[key]:.3e}")
|
||||||
self.losses[key] = 0.0
|
self.losses[key] = 0.0
|
||||||
|
|
||||||
total_loss += mse_loss
|
# total_loss += mse_loss
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|||||||
@@ -88,6 +88,18 @@ flux_config = {
|
|||||||
"use_dynamic_shifting": True
|
"use_dynamic_shifting": True
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sd_flow_config = {
|
||||||
|
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
||||||
|
"_diffusers_version": "0.30.0.dev0",
|
||||||
|
"base_image_seq_len": 256,
|
||||||
|
"base_shift": 0.5,
|
||||||
|
"max_image_seq_len": 4096,
|
||||||
|
"max_shift": 1.15,
|
||||||
|
"num_train_timesteps": 1000,
|
||||||
|
"shift": 3.0,
|
||||||
|
"use_dynamic_shifting": False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_sampler(
|
def get_sampler(
|
||||||
sampler: str,
|
sampler: str,
|
||||||
@@ -133,6 +145,8 @@ def get_sampler(
|
|||||||
elif sampler == "flowmatch":
|
elif sampler == "flowmatch":
|
||||||
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
|
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
|
||||||
config_to_use = copy.deepcopy(flux_config)
|
config_to_use = copy.deepcopy(flux_config)
|
||||||
|
if arch == "sd":
|
||||||
|
config_to_use = copy.deepcopy(sd_flow_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Sampler {sampler} not supported")
|
raise ValueError(f"Sampler {sampler} not supported")
|
||||||
|
|
||||||
|
|||||||
@@ -974,12 +974,17 @@ class StableDiffusion:
|
|||||||
"prediction_type": self.prediction_type,
|
"prediction_type": self.prediction_type,
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
|
arch = 'sd'
|
||||||
|
if self.model_config.is_pixart:
|
||||||
|
arch = 'pixart'
|
||||||
|
if self.model_config.is_flux:
|
||||||
|
arch = 'flux'
|
||||||
noise_scheduler = get_sampler(
|
noise_scheduler = get_sampler(
|
||||||
sampler,
|
sampler,
|
||||||
{
|
{
|
||||||
"prediction_type": self.prediction_type,
|
"prediction_type": self.prediction_type,
|
||||||
},
|
},
|
||||||
'sd' if not self.is_pixart else 'pixart'
|
arch
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user