Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
@@ -30,8 +30,11 @@ if TYPE_CHECKING:
     from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

 try:
-    from diffusers import QwenImageEditPlusPipeline
-    from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, VAE_IMAGE_SIZE
+    from .qwen_image_pipelines import QwenImageEditPlusCustomPipeline
+    from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import (
+        CONDITION_IMAGE_SIZE,
+        VAE_IMAGE_SIZE,
+    )
 except ImportError:
     raise ImportError(
         "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'"
@@ -41,7 +44,7 @@ except ImportError:
 class QwenImageEditPlusModel(QwenImageModel):
     arch = "qwen_image_edit_plus"
     _qwen_image_keep_visual = True
-    _qwen_pipeline = QwenImageEditPlusPipeline
+    _qwen_pipeline = QwenImageEditPlusCustomPipeline

     def __init__(
         self,
@@ -72,7 +75,7 @@ class QwenImageEditPlusModel(QwenImageModel):
     def get_generation_pipeline(self):
         scheduler = QwenImageModel.get_train_scheduler()

-        pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline(
+        pipeline: QwenImageEditPlusCustomPipeline = QwenImageEditPlusCustomPipeline(
             scheduler=scheduler,
             text_encoder=unwrap_model(self.text_encoder[0]),
             tokenizer=self.tokenizer[0],
@@ -87,7 +90,7 @@ class QwenImageEditPlusModel(QwenImageModel):

     def generate_single_image(
         self,
-        pipeline: QwenImageEditPlusPipeline,
+        pipeline: QwenImageEditPlusCustomPipeline,
         gen_config: GenerateImageConfig,
         conditional_embeds: PromptEmbeds,
         unconditional_embeds: PromptEmbeds,
@@ -147,6 +150,7 @@ class QwenImageEditPlusModel(QwenImageModel):
             latents=gen_config.latents,
             generator=generator,
             callback_on_step_end=callback_on_step_end,
+            do_cfg_norm=gen_config.do_cfg_norm,
             **extra,
         ).images[0]
         return img
@@ -223,7 +227,9 @@ class QwenImageEditPlusModel(QwenImageModel):

         if batch.control_tensor_list is not None:
             if len(batch.control_tensor_list) != batch_size:
-                raise ValueError("Control tensor list length does not match batch size")
+                raise ValueError(
+                    "Control tensor list length does not match batch size"
+                )
             b = 0
             for control_tensor_list in batch.control_tensor_list:
                 # control tensor list is a list of tensors for this batch item
@@ -231,7 +237,9 @@ class QwenImageEditPlusModel(QwenImageModel):
                 # pack control
                 for control_img in control_tensor_list:
                     # control images are 0 - 1 scale, shape (1, ch, height, width)
-                    control_img = control_img.to(self.device_torch, dtype=self.torch_dtype)
+                    control_img = control_img.to(
+                        self.device_torch, dtype=self.torch_dtype
+                    )
                     # if it is only 3 dim, add batch dim
                     if len(control_img.shape) == 3:
                         control_img = control_img.unsqueeze(0)
@@ -255,25 +263,41 @@ class QwenImageEditPlusModel(QwenImageModel):
                         dtype=self.torch_dtype,
                     )

-                    clb, cl_num_channels_latents, cl_height, cl_width = control_latent.shape
+                    clb, cl_num_channels_latents, cl_height, cl_width = (
+                        control_latent.shape
+                    )

                     control = control_latent.view(
-                        1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2
+                        1,
+                        cl_num_channels_latents,
+                        cl_height // 2,
+                        2,
+                        cl_width // 2,
+                        2,
                     )
                     control = control.permute(0, 2, 4, 1, 3, 5)
                     control = control.reshape(
-                        1, (cl_height // 2) * (cl_width // 2), num_channels_latents * 4
+                        1,
+                        (cl_height // 2) * (cl_width // 2),
+                        num_channels_latents * 4,
                     )

                     img_shapes[b].append((1, cl_height // 2, cl_width // 2))
                     controls.append(control)

                 # stack controls on dim 1
-                control = torch.cat(controls, dim=1).to(packed_latents_list[b].device, dtype=packed_latents_list[b].dtype)
+                control = torch.cat(controls, dim=1).to(
+                    packed_latents_list[b].device,
+                    dtype=packed_latents_list[b].dtype,
+                )
                 # concat with latents
-                packed_latents_with_control = torch.cat([packed_latents_list[b], control], dim=1)
+                packed_latents_with_control = torch.cat(
+                    [packed_latents_list[b], control], dim=1
+                )

-                packed_latents_with_controls_list.append(packed_latents_with_control)
+                packed_latents_with_controls_list.append(
+                    packed_latents_with_control
+                )

                 b += 1

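Note: the wrapped view/permute/reshape above is the 2x2 patch packing the Qwen Image models apply to latents; each control latent of shape (1, C, H, W) becomes (1, (H/2)*(W/2), C*4) tokens that can be concatenated with the packed image latents along the sequence dimension. A minimal standalone sketch with assumed example sizes follows (the hunk's final reshape uses num_channels_latents from the surrounding method, taken here to equal the control latent's channel count):

import torch

# assumed example sizes; real values come from the VAE-encoded control image
cl_num_channels_latents, cl_height, cl_width = 16, 64, 64
control_latent = torch.randn(1, cl_num_channels_latents, cl_height, cl_width)

# split H and W into a grid of 2x2 patches
control = control_latent.view(1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2)
# move the patch grid dims in front of the channel and intra-patch dims
control = control.permute(0, 2, 4, 1, 3, 5)
# flatten to (1, num_patches, channels * 4)
control = control.reshape(1, (cl_height // 2) * (cl_width // 2), cl_num_channels_latents * 4)
print(control.shape)  # torch.Size([1, 1024, 64])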
@@ -289,7 +313,9 @@ class QwenImageEditPlusModel(QwenImageModel):
         )

         noise_pred = self.transformer(
-            hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype).detach(),
+            hidden_states=latent_model_input.to(
+                self.device_torch, self.torch_dtype
+            ).detach(),
             timestep=(timestep / 1000).detach(),
             guidance=None,
             encoder_hidden_states=enc_hs.detach(),

qwen_image_pipelines.py (new file)
@@ -0,0 +1,354 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+
+try:
+    from diffusers import QwenImageEditPlusPipeline
+    from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import (
+        CONDITION_IMAGE_SIZE,
+        VAE_IMAGE_SIZE,
+        XLA_AVAILABLE,
+        logger,
+        calculate_dimensions,
+        calculate_shift,
+        retrieve_timesteps,
+    )
+except ImportError:
+    raise ImportError(
+        "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'"
+    )
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
+
+
+class QwenImageEditPlusCustomPipeline(QwenImageEditPlusPipeline):
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: Optional[PipelineImageInput] = None,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        true_cfg_scale: float = 4.0,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: Optional[float] = None,
+        num_images_per_prompt: int = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        do_cfg_norm: bool = False,
+    ):
+        image_size = image[-1].size if isinstance(image, list) else image.size
+        calculated_width, calculated_height = calculate_dimensions(
+            1024 * 1024, image_size[0] / image_size[1]
+        )
+        height = height or calculated_height
+        width = width or calculated_width
+
+        multiple_of = self.vae_scale_factor * 2
+        width = width // multiple_of * multiple_of
+        height = height // multiple_of * multiple_of
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_embeds_mask=prompt_embeds_mask,
+            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # 3. Preprocess image
+        if image is not None and not (
+            isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels
+        ):
+            if not isinstance(image, list):
+                image = [image]
+            condition_image_sizes = []
+            condition_images = []
+            vae_image_sizes = []
+            vae_images = []
+            for img in image:
+                image_width, image_height = img.size
+                condition_width, condition_height = calculate_dimensions(
+                    CONDITION_IMAGE_SIZE, image_width / image_height
+                )
+                vae_width, vae_height = calculate_dimensions(
+                    VAE_IMAGE_SIZE, image_width / image_height
+                )
+                condition_image_sizes.append((condition_width, condition_height))
+                vae_image_sizes.append((vae_width, vae_height))
+                condition_images.append(
+                    self.image_processor.resize(img, condition_height, condition_width)
+                )
+                vae_images.append(
+                    self.image_processor.preprocess(
+                        img, vae_height, vae_width
+                    ).unsqueeze(2)
+                )
+
+        has_neg_prompt = negative_prompt is not None or (
+            negative_prompt_embeds is not None
+            and negative_prompt_embeds_mask is not None
+        )
+
+        if true_cfg_scale > 1 and not has_neg_prompt:
+            logger.warning(
+                f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+            )
+        elif true_cfg_scale <= 1 and has_neg_prompt:
+            logger.warning(
+                " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+            )
+
+        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+            image=condition_images,
+            prompt=prompt,
+            prompt_embeds=prompt_embeds,
+            prompt_embeds_mask=prompt_embeds_mask,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+        )
+        if do_true_cfg:
+            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+                image=condition_images,
+                prompt=negative_prompt,
+                prompt_embeds=negative_prompt_embeds,
+                prompt_embeds_mask=negative_prompt_embeds_mask,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+            )
+
+        # 4. Prepare latent variables
+        num_channels_latents = self.transformer.config.in_channels // 4
+        latents, image_latents = self.prepare_latents(
+            vae_images,
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+        img_shapes = [
+            [
+                (
+                    1,
+                    height // self.vae_scale_factor // 2,
+                    width // self.vae_scale_factor // 2,
+                ),
+                *[
+                    (
+                        1,
+                        vae_height // self.vae_scale_factor // 2,
+                        vae_width // self.vae_scale_factor // 2,
+                    )
+                    for vae_width, vae_height in vae_image_sizes
+                ],
+            ]
+        ] * batch_size
+
+        # 5. Prepare timesteps
+        sigmas = (
+            np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+            if sigmas is None
+            else sigmas
+        )
+        image_seq_len = latents.shape[1]
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        num_warmup_steps = max(
+            len(timesteps) - num_inference_steps * self.scheduler.order, 0
+        )
+        self._num_timesteps = len(timesteps)
+
+        # handle guidance
+        if self.transformer.config.guidance_embeds and guidance_scale is None:
+            raise ValueError("guidance_scale is required for guidance-distilled model.")
+        elif self.transformer.config.guidance_embeds:
+            guidance = torch.full(
+                [1], guidance_scale, device=device, dtype=torch.float32
+            )
+            guidance = guidance.expand(latents.shape[0])
+        elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+            logger.warning(
+                f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+            )
+            guidance = None
+        elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+            guidance = None
+
+        if self.attention_kwargs is None:
+            self._attention_kwargs = {}
+
+        txt_seq_lens = (
+            prompt_embeds_mask.sum(dim=1).tolist()
+            if prompt_embeds_mask is not None
+            else None
+        )
+        negative_txt_seq_lens = (
+            negative_prompt_embeds_mask.sum(dim=1).tolist()
+            if negative_prompt_embeds_mask is not None
+            else None
+        )
+
+        # 6. Denoising loop
+        self.scheduler.set_begin_index(0)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                self._current_timestep = t
+
+                latent_model_input = latents
+                if image_latents is not None:
+                    latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        encoder_hidden_states_mask=prompt_embeds_mask,
+                        encoder_hidden_states=prompt_embeds,
+                        img_shapes=img_shapes,
+                        txt_seq_lens=txt_seq_lens,
+                        attention_kwargs=self.attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_pred[:, : latents.size(1)]
+
+                if do_true_cfg:
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            img_shapes=img_shapes,
+                            txt_seq_lens=negative_txt_seq_lens,
+                            attention_kwargs=self.attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                        neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                    comb_pred = neg_noise_pred + true_cfg_scale * (
+                        noise_pred - neg_noise_pred
+                    )
+
+                    if do_cfg_norm:
+                        # the official code does this, but I find it hurts more often than it helps, leaving it optional but off by default
+                        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                        noise_pred = comb_pred * (cond_norm / noise_norm)
+                    else:
+                        noise_pred = comb_pred
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, return_dict=False
+                )[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                ):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        self._current_timestep = None
+        if output_type == "latent":
+            image = latents
+        else:
+            latents = self._unpack_latents(
+                latents, height, width, self.vae_scale_factor
+            )
+            latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+                1, self.vae.config.z_dim, 1, 1, 1
+            ).to(latents.device, latents.dtype)
+            latents = latents / latents_std + latents_mean
+            image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return QwenImagePipelineOutput(images=image)
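The only behavioral difference from the stock QwenImageEditPlusPipeline is the extra do_cfg_norm argument, which rescales the CFG-combined prediction back to the norm of the conditional prediction. A hedged usage sketch follows; the import path, checkpoint id, and file paths are assumptions, while the call arguments mirror the __call__ signature above:

import torch
from PIL import Image

# assumed import path; adjust to wherever qwen_image_pipelines.py lives in the package
from qwen_image_pipelines import QwenImageEditPlusCustomPipeline

# assumed checkpoint id for Qwen Image Edit 2509
pipe = QwenImageEditPlusCustomPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

src = Image.open("input.png")  # assumed input image
result = pipe(
    image=[src],
    prompt="make the sky a deep sunset orange",
    negative_prompt=" ",
    true_cfg_scale=4.0,
    num_inference_steps=40,
    do_cfg_norm=True,  # off by default; rescales the combined prediction to the conditional norm
).images[0]
result.save("output.png")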
@@ -348,6 +348,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 ctrl_img_1=sample_item.ctrl_img_1,
                 ctrl_img_2=sample_item.ctrl_img_2,
                 ctrl_img_3=sample_item.ctrl_img_3,
+                do_cfg_norm=sample_config.do_cfg_norm,
                 **extra_args
             ))

@@ -70,6 +70,8 @@ class SampleItem:
             print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0")
             self.network_multiplier = 1.0

+        # only for models that support it, (qwen image edit 2509 for now)
+        self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False)

 class SampleConfig:
     def __init__(self, **kwargs):
@@ -104,6 +106,8 @@ class SampleConfig:
         ]
         raw_samples = kwargs.get('samples', default_samples_kwargs)
         self.samples = [SampleItem(self, **item) for item in raw_samples]
+        # only for models that support it, (qwen image edit 2509 for now)
+        self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False)

     @property
     def prompts(self):
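During training-time sampling the flag is read here and forwarded through GenerateImageConfig into the custom pipeline (see the BaseSDTrainProcess hunk above and the GenerateImageConfig hunks below). A minimal sketch of the config side; the toolkit.config_modules import path and the SampleItem keys are assumptions:

from toolkit.config_modules import SampleConfig  # assumed module path

sample_config = SampleConfig(
    do_cfg_norm=True,  # only honored by models that support it (qwen image edit 2509 for now)
    samples=[
        {"prompt": "turn the jacket bright red", "ctrl_img_1": "/path/to/source.png"},  # assumed keys
    ],
)
# BaseSDTrainProcess passes this on as GenerateImageConfig(do_cfg_norm=sample_config.do_cfg_norm, ...)
print(sample_config.do_cfg_norm)  # True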
@@ -993,7 +997,8 @@ class GenerateImageConfig:
         ctrl_img_3: Optional[str] = None, # third control image for multi control model
         num_frames: int = 1,
         fps: int = 15,
-        ctrl_idx: int = 0
+        ctrl_idx: int = 0,
+        do_cfg_norm: bool = False,
     ):
         self.width: int = width
         self.height: int = height
@@ -1064,6 +1069,8 @@ class GenerateImageConfig:

         self.logger = logger

+        self.do_cfg_norm: bool = do_cfg_norm
+
     def set_gen_time(self, gen_time: int = None):
         if gen_time is not None:
             self.gen_time = gen_time
toolkit/memory_management/__init__.py (new file)
@@ -0,0 +1 @@
+from .manager import MemoryManager

toolkit/memory_management/manager.py (new file)
@@ -0,0 +1,12 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from toolkit.models.base_model import BaseModel
+
+
+class MemoryManager:
+    def __init__(
+        self,
+        model: "BaseModel",
+    ):
+        self.model: "BaseModel" = model
@@ -41,6 +41,7 @@ from torchvision.transforms import functional as TF
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
+from toolkit.memory_management import MemoryManager

 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
@@ -186,6 +187,8 @@ class BaseModel:
         # do not resize control images
         self.use_raw_control_images = False

+        self.memory_manager = MemoryManager(self)
+
     # properties for old arch for backwards compatibility
     @property
     def unet(self):
@@ -70,6 +70,7 @@ from typing import TYPE_CHECKING
 from toolkit.print import print_acc
 from diffusers import FluxFillPipeline
 from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel
+from toolkit.memory_management import MemoryManager

 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
@@ -224,6 +225,8 @@ class StableDiffusion:
         # do not resize control images
         self.use_raw_control_images = False

+        self.memory_manager = MemoryManager(self)
+
     # properties for old arch for backwards compatibility
     @property
     def is_xl(self):