mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Small fixes for DFE, polar guidance, and other things
This commit is contained in:
@@ -6,6 +6,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.config_modules import TrainConfig
|
||||
|
||||
GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]
|
||||
|
||||
@@ -407,6 +408,7 @@ def get_guided_loss_polarity(
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
train_config: 'TrainConfig',
|
||||
scaler=None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -423,8 +425,22 @@ def get_guided_loss_polarity(
|
||||
target_neg = noise
|
||||
|
||||
if sd.is_flow_matching:
|
||||
# set the timesteps for flow matching as linear since we will do weighing
|
||||
sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
|
||||
linear_timesteps = any([
|
||||
train_config.linear_timesteps,
|
||||
train_config.linear_timesteps2,
|
||||
train_config.timestep_type == 'linear',
|
||||
])
|
||||
|
||||
timestep_type = 'linear' if linear_timesteps else None
|
||||
if timestep_type is None:
|
||||
timestep_type = train_config.timestep_type
|
||||
|
||||
sd.noise_scheduler.set_train_timesteps(
|
||||
1000,
|
||||
device=device,
|
||||
timestep_type=timestep_type,
|
||||
latents=conditional_latents
|
||||
)
|
||||
target_pos = (noise - conditional_latents).detach()
|
||||
target_neg = (noise - unconditional_latents).detach()
|
||||
|
||||
@@ -481,11 +497,6 @@ def get_guided_loss_polarity(
|
||||
|
||||
loss = pred_loss + pred_neg_loss
|
||||
|
||||
# if sd.is_flow_matching:
|
||||
# timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
|
||||
# loss = loss * timestep_weight
|
||||
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = loss.mean()
|
||||
if scaler is not None:
|
||||
@@ -609,6 +620,7 @@ def get_guidance_loss(
|
||||
mask_multiplier=None,
|
||||
prior_pred=None,
|
||||
scaler=None,
|
||||
train_config=None,
|
||||
**kwargs
|
||||
):
|
||||
# TODO add others and process individual batch items separately
|
||||
@@ -641,6 +653,7 @@ def get_guidance_loss(
|
||||
noise,
|
||||
sd,
|
||||
scaler=scaler,
|
||||
train_config=train_config,
|
||||
**kwargs
|
||||
)
|
||||
elif guidance_type == "tnt":
|
||||
|
||||
Reference in New Issue
Block a user