mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Make a CFG version of flux pipeline
This commit is contained in:
@@ -78,8 +78,12 @@ for key, value in state_dict.items():
|
|||||||
new_key = new_key.replace('lora_up', 'lora_B')
|
new_key = new_key.replace('lora_up', 'lora_B')
|
||||||
new_key = new_key.replace('_lora', '.lora')
|
new_key = new_key.replace('_lora', '.lora')
|
||||||
new_key = new_key.replace('attn_', 'attn.')
|
new_key = new_key.replace('attn_', 'attn.')
|
||||||
|
new_key = new_key.replace('ff_', 'ff.')
|
||||||
|
new_key = new_key.replace('context_net_', 'context.net.')
|
||||||
|
new_key = new_key.replace('0_proj', '0.proj')
|
||||||
new_key = new_key.replace('norm_linear', 'norm.linear')
|
new_key = new_key.replace('norm_linear', 'norm.linear')
|
||||||
new_key = new_key.replace('norm_out_linear', 'norm_out.linear')
|
new_key = new_key.replace('norm_out_linear', 'norm_out.linear')
|
||||||
|
new_key = new_key.replace('to_out_', 'to_out.')
|
||||||
|
|
||||||
new_state_dict[new_key] = new_val.to(orig_dtype)
|
new_state_dict[new_key] = new_val.to(orig_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -370,6 +370,7 @@ class ModelConfig:
|
|||||||
self.is_flux: bool = kwargs.get('is_flux', False)
|
self.is_flux: bool = kwargs.get('is_flux', False)
|
||||||
if self.is_pixart_sigma:
|
if self.is_pixart_sigma:
|
||||||
self.is_pixart = True
|
self.is_pixart = True
|
||||||
|
self.use_flux_cfg = kwargs.get('use_flux_cfg', False)
|
||||||
self.is_ssd: bool = kwargs.get('is_ssd', False)
|
self.is_ssd: bool = kwargs.get('is_ssd', False)
|
||||||
self.is_vega: bool = kwargs.get('is_vega', False)
|
self.is_vega: bool = kwargs.get('is_vega', False)
|
||||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||||
|
|||||||
@@ -2,8 +2,11 @@ import importlib
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
|
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
|
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline
|
||||||
|
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||||
|
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
|
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
|
||||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||||
@@ -1202,3 +1205,217 @@ class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
|
|||||||
|
|
||||||
return StableDiffusionXLPipelineOutput(images=image)
|
return StableDiffusionXLPipelineOutput(images=image)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO this is rough. Need to properly stack unconditional
|
||||||
|
class FluxWithCFGPipeline(FluxPipeline):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 28,
|
||||||
|
timesteps: List[int] = None,
|
||||||
|
guidance_scale: float = 7.0,
|
||||||
|
num_images_per_prompt: Optional[int] = 1,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||||
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
|
||||||
|
height = height or self.default_sample_size * self.vae_scale_factor
|
||||||
|
width = width or self.default_sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
# 1. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(
|
||||||
|
prompt,
|
||||||
|
prompt_2,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||||
|
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._joint_attention_kwargs = joint_attention_kwargs
|
||||||
|
self._interrupt = False
|
||||||
|
|
||||||
|
# 2. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
|
||||||
|
lora_scale = (
|
||||||
|
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||||
|
)
|
||||||
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
pooled_prompt_embeds,
|
||||||
|
text_ids,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_2=prompt_2,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||||
|
device=device,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
lora_scale=lora_scale,
|
||||||
|
)
|
||||||
|
(
|
||||||
|
negative_prompt_embeds,
|
||||||
|
negative_pooled_prompt_embeds,
|
||||||
|
negative_text_ids,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
prompt=negative_prompt,
|
||||||
|
prompt_2=negative_prompt_2,
|
||||||
|
prompt_embeds=negative_prompt_embeds,
|
||||||
|
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||||
|
device=device,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
lora_scale=lora_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Prepare latent variables
|
||||||
|
num_channels_latents = self.transformer.config.in_channels // 4
|
||||||
|
latents, latent_image_ids = self.prepare_latents(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds.dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Prepare timesteps
|
||||||
|
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||||
|
image_seq_len = latents.shape[1]
|
||||||
|
mu = calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
self.scheduler.config.base_image_seq_len,
|
||||||
|
self.scheduler.config.max_image_seq_len,
|
||||||
|
self.scheduler.config.base_shift,
|
||||||
|
self.scheduler.config.max_shift,
|
||||||
|
)
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
self.scheduler,
|
||||||
|
num_inference_steps,
|
||||||
|
device,
|
||||||
|
timesteps,
|
||||||
|
sigmas,
|
||||||
|
mu=mu,
|
||||||
|
)
|
||||||
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
# 6. Denoising loop
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||||
|
|
||||||
|
# handle guidance
|
||||||
|
if self.transformer.config.guidance_embeds:
|
||||||
|
guidance = torch.tensor([guidance_scale], device=device)
|
||||||
|
guidance = guidance.expand(latents.shape[0])
|
||||||
|
else:
|
||||||
|
guidance = None
|
||||||
|
|
||||||
|
noise_pred_text = self.transformer(
|
||||||
|
hidden_states=latents,
|
||||||
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
|
timestep=timestep / 1000,
|
||||||
|
guidance=guidance,
|
||||||
|
pooled_projections=pooled_prompt_embeds,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
txt_ids=text_ids,
|
||||||
|
img_ids=latent_image_ids,
|
||||||
|
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# todo combine these
|
||||||
|
noise_pred_uncond = self.transformer(
|
||||||
|
hidden_states=latents,
|
||||||
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
|
timestep=timestep / 1000,
|
||||||
|
guidance=guidance,
|
||||||
|
pooled_projections=negative_pooled_prompt_embeds,
|
||||||
|
encoder_hidden_states=negative_prompt_embeds,
|
||||||
|
txt_ids=negative_text_ids,
|
||||||
|
img_ids=latent_image_ids,
|
||||||
|
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents_dtype = latents.dtype
|
||||||
|
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
if latents.dtype != latents_dtype:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||||
|
latents = latents.to(latents_dtype)
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if XLA_AVAILABLE:
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
|
||||||
|
else:
|
||||||
|
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||||
|
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||||
|
image = self.vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
return FluxPipelineOutput(images=image)
|
||||||
@@ -36,7 +36,7 @@ from toolkit.sd_device_states_presets import empty_preset
|
|||||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
import torch
|
import torch
|
||||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
|
||||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||||
@@ -797,16 +797,29 @@ class StableDiffusion:
|
|||||||
).to(self.device_torch)
|
).to(self.device_torch)
|
||||||
pipeline.watermark = None
|
pipeline.watermark = None
|
||||||
elif self.is_flux:
|
elif self.is_flux:
|
||||||
pipeline = FluxPipeline(
|
if self.model_config.use_flux_cfg:
|
||||||
vae=self.vae,
|
pipeline = FluxWithCFGPipeline(
|
||||||
transformer=self.unet,
|
vae=self.vae,
|
||||||
text_encoder=self.text_encoder[0],
|
transformer=self.unet,
|
||||||
text_encoder_2=self.text_encoder[1],
|
text_encoder=self.text_encoder[0],
|
||||||
tokenizer=self.tokenizer[0],
|
text_encoder_2=self.text_encoder[1],
|
||||||
tokenizer_2=self.tokenizer[1],
|
tokenizer=self.tokenizer[0],
|
||||||
scheduler=noise_scheduler,
|
tokenizer_2=self.tokenizer[1],
|
||||||
**extra_args
|
scheduler=noise_scheduler,
|
||||||
)
|
**extra_args
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
pipeline = FluxPipeline(
|
||||||
|
vae=self.vae,
|
||||||
|
transformer=self.unet,
|
||||||
|
text_encoder=self.text_encoder[0],
|
||||||
|
text_encoder_2=self.text_encoder[1],
|
||||||
|
tokenizer=self.tokenizer[0],
|
||||||
|
tokenizer_2=self.tokenizer[1],
|
||||||
|
scheduler=noise_scheduler,
|
||||||
|
**extra_args
|
||||||
|
)
|
||||||
pipeline.watermark = None
|
pipeline.watermark = None
|
||||||
elif self.is_v3:
|
elif self.is_v3:
|
||||||
pipeline = Pipe(
|
pipeline = Pipe(
|
||||||
@@ -1068,18 +1081,32 @@ class StableDiffusion:
|
|||||||
**extra
|
**extra
|
||||||
).images[0]
|
).images[0]
|
||||||
elif self.is_flux:
|
elif self.is_flux:
|
||||||
img = pipeline(
|
if self.model_config.use_flux_cfg:
|
||||||
prompt_embeds=conditional_embeds.text_embeds,
|
img = pipeline(
|
||||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
prompt_embeds=conditional_embeds.text_embeds,
|
||||||
# negative_prompt_embeds=unconditional_embeds.text_embeds,
|
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||||
# negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||||
height=gen_config.height,
|
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||||
width=gen_config.width,
|
height=gen_config.height,
|
||||||
num_inference_steps=gen_config.num_inference_steps,
|
width=gen_config.width,
|
||||||
guidance_scale=gen_config.guidance_scale,
|
num_inference_steps=gen_config.num_inference_steps,
|
||||||
latents=gen_config.latents,
|
guidance_scale=gen_config.guidance_scale,
|
||||||
**extra
|
latents=gen_config.latents,
|
||||||
).images[0]
|
**extra
|
||||||
|
).images[0]
|
||||||
|
else:
|
||||||
|
img = pipeline(
|
||||||
|
prompt_embeds=conditional_embeds.text_embeds,
|
||||||
|
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||||
|
# negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||||
|
# negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||||
|
height=gen_config.height,
|
||||||
|
width=gen_config.width,
|
||||||
|
num_inference_steps=gen_config.num_inference_steps,
|
||||||
|
guidance_scale=gen_config.guidance_scale,
|
||||||
|
latents=gen_config.latents,
|
||||||
|
**extra
|
||||||
|
).images[0]
|
||||||
elif self.is_pixart:
|
elif self.is_pixart:
|
||||||
# needs attention masks for some reason
|
# needs attention masks for some reason
|
||||||
img = pipeline(
|
img = pipeline(
|
||||||
|
|||||||
Reference in New Issue
Block a user