mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Adapter work. Bug fixes. Auto adjust LR when resuming optimizer.
This commit is contained in:
@@ -863,6 +863,7 @@ class StableDiffusion:
|
||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
is_input_scaled=False,
|
||||
detach_unconditional=False,
|
||||
rescale_cfg=None,
|
||||
**kwargs,
|
||||
):
|
||||
# get the embeddings
|
||||
@@ -1111,6 +1112,21 @@ class StableDiffusion:
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
if rescale_cfg is not None and rescale_cfg != guidance_scale:
|
||||
with torch.no_grad():
|
||||
# do cfg at the target rescale so we can match it
|
||||
target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
|
||||
target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
|
||||
|
||||
pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach()
|
||||
pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach()
|
||||
|
||||
# match the mean and std
|
||||
noise_pred = (noise_pred - pred_mean) / pred_std
|
||||
noise_pred = (noise_pred * target_std) + target_mean
|
||||
|
||||
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||
if guidance_rescale > 0.0:
|
||||
|
||||
Reference in New Issue
Block a user