mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-02-04 20:49:58 +00:00
Added initial support for f-lite model
@@ -1,7 +1,8 @@
 from .chroma import ChromaModel
 from .hidream import HidreamModel
+from .f_light import FLiteModel

 AI_TOOLKIT_MODELS = [
     # put a list of models here
-    ChromaModel, HidreamModel
+    ChromaModel, HidreamModel, FLiteModel
 ]
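For illustration only (not part of this commit): a minimal, self-contained sketch of how a registry list like AI_TOOLKIT_MODELS can be resolved by the arch string each model class declares (FLiteModel sets arch = "f-lite" in f_light.py below). The stub classes and the resolve_arch helper are hypothetical; only the "f-lite" arch value comes from the diff.

class ChromaModel:
    arch = "chroma"   # placeholder stub, not the real class

class HidreamModel:
    arch = "hidream"  # placeholder stub, not the real class

class FLiteModel:
    arch = "f-lite"   # matches FLiteModel.arch in this commit

AI_TOOLKIT_MODELS = [ChromaModel, HidreamModel, FLiteModel]

def resolve_arch(arch: str):
    # return the first registered model class whose declared arch matches
    for cls in AI_TOOLKIT_MODELS:
        if getattr(cls, "arch", None) == arch:
            return cls
    raise ValueError(f"unknown arch: {arch}")

print(resolve_arch("f-lite").__name__)  # -> FLiteModel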
1  extensions_built_in/diffusion_models/f_light/__init__.py  Normal file
@@ -0,0 +1 @@
from .f_light import FLiteModel
295  extensions_built_in/diffusion_models/f_light/f_light.py  Normal file
@@ -0,0 +1,295 @@
import os
from typing import TYPE_CHECKING

import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from diffusers import AutoencoderKL
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel
from .src import FLitePipeline, DiT

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.5,
    "max_image_seq_len": 4096,
    "max_shift": 1.15,
    "num_train_timesteps": 1000,
    "shift": 3.0,
    "use_dynamic_shifting": True
}


class FLiteModel(BaseModel):
    arch = "f-lite"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['DiT']

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        # return the bucket divisibility for the model
        return 16

    def load_model(self):
        dtype = self.torch_dtype

        # will be updated if we detect an existing checkpoint in the training folder
        model_path = self.model_config.name_or_path

        extras_path = self.model_config.extras_name_or_path

        self.print_and_status_update("Loading transformer")

        transformer = DiT.from_pretrained(
            model_path,
            subfolder="dit_model",
            torch_dtype=dtype,
        )

        transformer.to(self.quantize_device, dtype=dtype)

        if self.model_config.quantize:
            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = get_qtype(self.model_config.qtype)
            self.print_and_status_update("Quantizing transformer")
            quantize(transformer, weights=quantization_type,
                     **self.model_config.quantize_kwargs)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()

        self.print_and_status_update("Loading T5")
        tokenizer = T5TokenizerFast.from_pretrained(
            extras_path, subfolder="tokenizer", torch_dtype=dtype
        )
        text_encoder = T5EncoderModel.from_pretrained(
            extras_path, subfolder="text_encoder", torch_dtype=dtype
        )
        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing T5")
            quantize(text_encoder, weights=get_qtype(
                self.model_config.qtype))
            freeze(text_encoder)
            flush()

        self.noise_scheduler = FLiteModel.get_train_scheduler()

        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            extras_path,
            subfolder="vae",
            torch_dtype=dtype
        )
        vae = vae.to(self.device_torch, dtype=dtype)

        self.print_and_status_update("Making pipe")

        pipe: FLitePipeline = FLitePipeline(
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            dit_model=None,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder = text_encoder
        pipe.dit_model = transformer
        pipe.transformer = transformer
        pipe.scheduler = self.noise_scheduler

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        scheduler = FLiteModel.get_train_scheduler()
        # the pipeline has a built-in scheduler; basically Euler flow matching
        pipeline = FLitePipeline(
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            vae=unwrap_model(self.vae),
            dit_model=unwrap_model(self.transformer)
        )
        pipeline.transformer = pipeline.dit_model
        pipeline.scheduler = scheduler

        return pipeline

    def generate_single_image(
        self,
        pipeline: FLitePipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):

        extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        cast_dtype = self.unet.dtype

        noise_pred = self.unet(
            latent_model_input.to(
                self.device_torch, cast_dtype
            ),
            text_embeddings.text_embeds.to(
                self.device_torch, cast_dtype
            ),
            timestep / 1000,
        )

        if isinstance(noise_pred, QTensor):
            noise_pred = noise_pred.dequantize()

        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if isinstance(prompt, str):
            prompts = [prompt]
        else:
            prompts = prompt
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        prompt_embeds, negative_embeds = self.pipeline.encode_prompt(
            prompt=prompts,
            negative_prompt=None,
            device=self.text_encoder[0].device,
            dtype=self.torch_dtype,
        )

        pe = PromptEmbeds(prompt_embeds)

        return pe

    def get_model_has_grad(self):
        # return from a weight if it has grad
        return False

    def get_te_has_grad(self):
        # return from a weight if it has grad
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the unet
        transformer: DiT = unwrap_model(self.model)
        # diffusers
        # only save the unet
        transformer: DiT = unwrap_model(self.transformer)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'dit_model'),
            safe_serialization=True,
        )
        # save out meta config
        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        # return (noise - batch.latents).detach()
        return (batch.latents - noise).detach()

    def convert_lora_weights_before_save(self, state_dict):
        # keys currently start with "transformer." but need to start with "diffusion_model." for ComfyUI
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # saved with the "diffusion_model." prefix, but ai-toolkit expects "transformer."
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        return "f-lite"

    def get_stepped_pred(self, pred, noise):
        # just used for DFE support
        latents = pred + noise
        return latents
5  extensions_built_in/diffusion_models/f_light/src/__init__.py  Normal file
@@ -0,0 +1,5 @@
from .pipeline import FLitePipeline, FLitePipelineOutput, APGConfig
from .model import DiT


__all__ = ["FLitePipeline", "FLitePipelineOutput", "APGConfig", "DiT"]
456  extensions_built_in/diffusion_models/f_light/src/model.py  Normal file
@@ -0,0 +1,456 @@
# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/model.py but modified slightly

import math

import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange
from peft import get_peft_model_state_dict, set_peft_model_state_dict
from torch import nn


def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=t.device
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    return embedding


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6, trainable=False):
        super().__init__()
        self.eps = eps
        if trainable:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, x):
        x_dtype = x.dtype
        x = x.float()
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        if self.weight is not None:
            return (x * norm * self.weight).to(dtype=x_dtype)
        else:
            return (x * norm).to(dtype=x_dtype)


class QKNorm(nn.Module):
    """Normalizing the query and the key independently, as Flux proposes"""

    def __init__(self, dim, trainable=False):
        super().__init__()
        self.query_norm = RMSNorm(dim, trainable=trainable)
        self.key_norm = RMSNorm(dim, trainable=trainable)

    def forward(self, q, k):
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q, k


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        is_self_attn=True,
        cross_attn_input_size=None,
        residual_v=False,
        dynamic_softmax_temperature=False,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.is_self_attn = is_self_attn
        self.residual_v = residual_v
        self.dynamic_softmax_temperature = dynamic_softmax_temperature

        if is_self_attn:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.q = nn.Linear(dim, dim, bias=qkv_bias)
            self.context_kv = nn.Linear(cross_attn_input_size, dim * 2, bias=qkv_bias)

        self.proj = nn.Linear(dim, dim, bias=False)

        if residual_v:
            self.lambda_param = nn.Parameter(torch.tensor(0.5).reshape(1))

        self.qk_norm = QKNorm(self.head_dim)

    def forward(self, x, context=None, v_0=None, rope=None):
        if self.is_self_attn:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "b l (k h d) -> k b h l d", k=3, h=self.num_heads)
            q, k, v = qkv.unbind(0)

            if self.residual_v and v_0 is not None:
                v = self.lambda_param * v + (1 - self.lambda_param) * v_0

            if rope is not None:
                # print(q.shape, rope[0].shape, rope[1].shape)
                q = apply_rotary_emb(q, rope[0], rope[1])
                k = apply_rotary_emb(k, rope[0], rope[1])

            # https://arxiv.org/abs/2306.08645
            # https://arxiv.org/abs/2410.01104
            # rationale is that as the number of tokens grows, the categorical distribution gets more uniform,
            # so you want to enlarge the entropy.

            token_length = q.shape[2]
            if self.dynamic_softmax_temperature:
                ratio = math.sqrt(math.log(token_length) / math.log(1040.0))  # 1024 + 16
                k = k * ratio
            q, k = self.qk_norm(q, k)

        else:
            q = rearrange(self.q(x), "b l (h d) -> b h l d", h=self.num_heads)
            kv = rearrange(
                self.context_kv(context),
                "b l (k h d) -> k b h l d",
                k=2,
                h=self.num_heads,
            )
            k, v = kv.unbind(0)
            q, k = self.qk_norm(q, k)

        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b h l d -> b l (h d)")
        x = self.proj(x)
        return x, v if self.is_self_attn else None


class DiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        cross_attn_input_size,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        residual_v=False,
        dynamic_softmax_temperature=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = RMSNorm(hidden_size, trainable=qkv_bias)
        self.self_attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            is_self_attn=True,
            residual_v=residual_v,
            dynamic_softmax_temperature=dynamic_softmax_temperature,
        )

        if cross_attn_input_size is not None:
            self.norm2 = RMSNorm(hidden_size, trainable=qkv_bias)
            self.cross_attn = Attention(
                hidden_size,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                is_self_attn=False,
                cross_attn_input_size=cross_attn_input_size,
                dynamic_softmax_temperature=dynamic_softmax_temperature,
            )
        else:
            self.norm2 = None
            self.cross_attn = None

        self.norm3 = RMSNorm(hidden_size, trainable=qkv_bias)
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, hidden_size),
        )

        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size, bias=True))

        self.adaLN_modulation[-1].weight.data.zero_()
        self.adaLN_modulation[-1].bias.data.zero_()

    # @torch.compile(mode='reduce-overhead')
    def forward(self, x, context, c, v_0=None, rope=None):
        (
            shift_sa,
            scale_sa,
            gate_sa,
            shift_ca,
            scale_ca,
            gate_ca,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(9, dim=1)

        scale_sa = scale_sa[:, None, :]
        scale_ca = scale_ca[:, None, :]
        scale_mlp = scale_mlp[:, None, :]

        shift_sa = shift_sa[:, None, :]
        shift_ca = shift_ca[:, None, :]
        shift_mlp = shift_mlp[:, None, :]

        gate_sa = gate_sa[:, None, :]
        gate_ca = gate_ca[:, None, :]
        gate_mlp = gate_mlp[:, None, :]

        norm_x = self.norm1(x.clone())
        norm_x = norm_x * (1 + scale_sa) + shift_sa
        attn_out, v = self.self_attn(norm_x, v_0=v_0, rope=rope)
        x = x + attn_out * gate_sa

        if self.norm2 is not None:
            norm_x = self.norm2(x)
            norm_x = norm_x * (1 + scale_ca) + shift_ca
            x = x + self.cross_attn(norm_x, context)[0] * gate_ca

        norm_x = self.norm3(x)
        norm_x = norm_x * (1 + scale_mlp) + shift_mlp
        x = x + self.mlp(norm_x) * gate_mlp

        return x, v


class PatchEmbed(nn.Module):
    def __init__(self, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.patch_size = patch_size

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch_proj(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        return x


class TwoDimRotary(torch.nn.Module):
    def __init__(self, dim, base=10000, h=256, w=256):
        super().__init__()
        self.inv_freq = torch.FloatTensor([1.0 / (base ** (i / dim)) for i in range(0, dim, 2)])
        self.h = h
        self.w = w

        t_h = torch.arange(h, dtype=torch.float32)
        t_w = torch.arange(w, dtype=torch.float32)

        freqs_h = torch.outer(t_h, self.inv_freq).unsqueeze(1)  # h, 1, d / 2
        freqs_w = torch.outer(t_w, self.inv_freq).unsqueeze(0)  # 1, w, d / 2
        freqs_h = freqs_h.repeat(1, w, 1)  # h, w, d / 2
        freqs_w = freqs_w.repeat(h, 1, 1)  # h, w, d / 2
        freqs_hw = torch.cat([freqs_h, freqs_w], 2)  # h, w, d

        self.register_buffer("freqs_hw_cos", freqs_hw.cos())
        self.register_buffer("freqs_hw_sin", freqs_hw.sin())

    def forward(self, x, height_width=None, extend_with_register_tokens=0):
        if height_width is not None:
            this_h, this_w = height_width
        else:
            this_hw = x.shape[1]
            this_h, this_w = int(this_hw**0.5), int(this_hw**0.5)

        cos = self.freqs_hw_cos[0 : this_h, 0 : this_w]
        sin = self.freqs_hw_sin[0 : this_h, 0 : this_w]

        cos = cos.clone().reshape(this_h * this_w, -1)
        sin = sin.clone().reshape(this_h * this_w, -1)

        # append N of zero-attn tokens
        if extend_with_register_tokens > 0:
            cos = torch.cat(
                [
                    torch.ones(extend_with_register_tokens, cos.shape[1]).to(cos.device),
                    cos,
                ],
                0,
            )
            sin = torch.cat(
                [
                    torch.zeros(extend_with_register_tokens, sin.shape[1]).to(sin.device),
                    sin,
                ],
                0,
            )

        return cos[None, None, :, :], sin[None, None, :, :]  # [1, 1, T + N, Attn-dim]


def apply_rotary_emb(x, cos, sin):
    orig_dtype = x.dtype
    x = x.to(dtype=torch.float32)
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).to(dtype=orig_dtype)


class DiT(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):  # type: ignore[misc]
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels=4,
        patch_size=2,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        cross_attn_input_size=128,
        residual_v=False,
        train_bias_and_rms=True,
        use_rope=True,
        gradient_checkpoint=False,
        dynamic_softmax_temperature=False,
        rope_base=10000,
    ):
        super().__init__()

        self.patch_embed = PatchEmbed(patch_size, in_channels, hidden_size)

        if use_rope:
            self.rope = TwoDimRotary(hidden_size // (2 * num_heads), base=rope_base, h=512, w=512)
        else:
            self.positional_embedding = nn.Parameter(torch.zeros(1, 2048, hidden_size))

        self.register_tokens = nn.Parameter(torch.randn(1, 16, hidden_size))

        self.time_embed = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    cross_attn_input_size=cross_attn_input_size,
                    residual_v=residual_v,
                    qkv_bias=train_bias_and_rms,
                    dynamic_softmax_temperature=dynamic_softmax_temperature,
                )
                for _ in range(depth)
            ]
        )

        self.final_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

        self.final_norm = RMSNorm(hidden_size, trainable=train_bias_and_rms)
        self.final_proj = nn.Linear(hidden_size, patch_size * patch_size * in_channels)
        nn.init.zeros_(self.final_modulation[-1].weight)
        nn.init.zeros_(self.final_modulation[-1].bias)
        nn.init.zeros_(self.final_proj.weight)
        nn.init.zeros_(self.final_proj.bias)
        self.paramstatus = {}
        for n, p in self.named_parameters():
            self.paramstatus[n] = {
                "shape": p.shape,
                "requires_grad": p.requires_grad,
            }
        self.gradient_checkpointing = False

    def save_lora_weights(self, save_directory):
        """Save LoRA weights to a file"""
        lora_state_dict = get_peft_model_state_dict(self)
        torch.save(lora_state_dict, f"{save_directory}/lora_weights.pt")

    def load_lora_weights(self, load_directory):
        """Load LoRA weights from a file"""
        lora_state_dict = torch.load(f"{load_directory}/lora_weights.pt")
        set_peft_model_state_dict(self, lora_state_dict)

    @apply_forward_hook
    def forward(self, x, context, timesteps):
        b, c, h, w = x.shape
        x = self.patch_embed(x)  # b, T, d

        x = torch.cat([self.register_tokens.repeat(b, 1, 1), x], 1)  # b, T + N, d

        if self.config.use_rope:
            cos, sin = self.rope(
                x,
                extend_with_register_tokens=16,
                height_width=(h // self.config.patch_size, w // self.config.patch_size),
            )
        else:
            x = x + self.positional_embedding.repeat(b, 1, 1)[:, : x.shape[1], :]
            cos, sin = None, None

        t_emb = timestep_embedding(timesteps * 1000, self.config.hidden_size).to(x.device, dtype=x.dtype)
        t_emb = self.time_embed(t_emb)

        v_0 = None

        for _idx, block in enumerate(self.blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x, v = self._gradient_checkpointing_func(
                    block,
                    x,
                    context,
                    t_emb,
                    v_0,
                    (cos, sin)
                )
            else:
                x, v = block(x, context, t_emb, v_0, (cos, sin))
            if v_0 is None:
                v_0 = v

        x = x[:, 16:, :]
        final_shift, final_scale = self.final_modulation(t_emb).chunk(2, dim=1)
        x = self.final_norm(x)
        x = x * (1 + final_scale[:, None, :]) + final_shift[:, None, :]
        x = self.final_proj(x)

        x = rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            h=h // self.config.patch_size,
            w=w // self.config.patch_size,
            p1=self.config.patch_size,
            p2=self.config.patch_size,
        )
        return x


if __name__ == "__main__":
    model = DiT(
        in_channels=4,
        patch_size=2,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        cross_attn_input_size=128,
        residual_v=False,
        train_bias_and_rms=True,
        use_rope=True,
    ).cuda()
    print(
        model(
            torch.randn(1, 4, 64, 64).cuda(),
            torch.randn(1, 37, 128).cuda(),
            torch.tensor([1.0]).cuda(),
        )
    )
308  extensions_built_in/diffusion_models/f_light/src/pipeline.py  Normal file
@@ -0,0 +1,308 @@
# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/pipeline.py but modified slightly
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from torch import FloatTensor
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5TokenizerFast


logger = logging.getLogger(__name__)


@dataclass
class APGConfig:
    """APG (Augmented Parallel Guidance) configuration"""

    enabled: bool = True
    orthogonal_threshold: float = 0.03


@dataclass
class FLitePipelineOutput(BaseOutput):
    """
    Output class for FLitePipeline pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[Image.Image], np.ndarray]


class FLitePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using F-Lite model.
    This model inherits from [`DiffusionPipeline`].
    """

    model_cpu_offload_seq = "text_encoder->dit_model->vae"

    dit_model: torch.nn.Module
    vae: AutoencoderKL
    text_encoder: T5EncoderModel
    tokenizer: T5TokenizerFast
    _progress_bar_config: Dict[str, Any]

    def __init__(
        self, dit_model: torch.nn.Module, vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast
    ):
        super().__init__()
        # Register all modules for the pipeline
        # Access DiffusionPipeline's register_modules directly to avoid mypy error
        DiffusionPipeline.register_modules(
            self, dit_model=dit_model, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
        )

        # Move models to channels last for better performance
        # AutoencoderKL inherits from torch.nn.Module which has these methods
        if hasattr(self.vae, "to"):
            self.vae.to(memory_format=torch.channels_last)
        if hasattr(self.vae, "requires_grad_"):
            self.vae.requires_grad_(False)
        if hasattr(self.text_encoder, "requires_grad_"):
            self.text_encoder.requires_grad_(False)

        # Constants
        self.vae_scale_factor = 8
        self.return_index = -8  # T5 hidden state index to use

    def enable_vae_slicing(self):
        """Enable VAE slicing for memory efficiency."""
        if hasattr(self.vae, "enable_slicing"):
            self.vae.enable_slicing()

    def enable_vae_tiling(self):
        """Enable VAE tiling for memory efficiency."""
        if hasattr(self.vae, "enable_tiling"):
            self.vae.enable_tiling()

    def set_progress_bar_config(self, **kwargs):
        """Set progress bar configuration."""
        self._progress_bar_config = kwargs

    def progress_bar(self, iterable=None, **kwargs):
        """Create progress bar for iterations."""
        self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {}
        config = {**self._progress_bar_config, **kwargs}
        return tqdm(iterable, **config)

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        max_sequence_length: int = 512,
        return_index: int = -8,
    ) -> Tuple[FloatTensor, FloatTensor]:
        """Encodes the prompt and negative prompt."""
        if isinstance(prompt, str):
            prompt = [prompt]
        device = device or self.text_encoder.device
        # Text encoder forward pass
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_embeds = self.text_encoder(text_input_ids, return_dict=True, output_hidden_states=True)
        prompt_embeds_tensor = prompt_embeds.hidden_states[return_index]
        if return_index != -1:
            prompt_embeds_tensor = self.text_encoder.encoder.final_layer_norm(prompt_embeds_tensor)
            prompt_embeds_tensor = self.text_encoder.encoder.dropout(prompt_embeds_tensor)

        dtype = dtype or next(self.text_encoder.parameters()).dtype
        prompt_embeds_tensor = prompt_embeds_tensor.to(dtype=dtype, device=device)

        # Handle negative prompts
        if negative_prompt is None:
            negative_embeds = torch.zeros_like(prompt_embeds_tensor)
        else:
            if isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]
            negative_result = self.encode_prompt(
                prompt=negative_prompt, device=device, dtype=dtype, return_index=return_index
            )
            negative_embeds = negative_result[0]

        # Explicitly cast both tensors to FloatTensor for mypy
        from typing import cast

        prompt_tensor = cast(FloatTensor, prompt_embeds_tensor.to(dtype=dtype))
        negative_tensor = cast(FloatTensor, negative_embeds.to(dtype=dtype))
        return (prompt_tensor, negative_tensor)

    def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False):
        """Move pipeline components to specified device and dtype."""
        if hasattr(self, "vae"):
            self.vae.to(device=torch_device, dtype=torch_dtype)
        if hasattr(self, "text_encoder"):
            self.text_encoder.to(device=torch_device, dtype=torch_dtype)
        if hasattr(self, "dit_model"):
            self.dit_model.to(device=torch_device, dtype=torch_dtype)
        return self

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_embeds: Optional[FloatTensor] = None,
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 30,
        guidance_scale: float = 6.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_embeds: Optional[FloatTensor] = None,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        dtype: Optional[torch.dtype] = None,
        alpha: Optional[float] = None,
        apg_config: Optional[APGConfig] = None,
        **kwargs,
    ):
        """Generate images from text prompt."""
        # Ensure height and width are not None for calculation
        if height is None:
            height = 1024
        if width is None:
            width = 1024

        dtype = dtype or next(self.dit_model.parameters()).dtype
        apg_config = apg_config or APGConfig(enabled=False)

        device = self._execution_device

        # 2. Encode prompts
        prompt_batch_size = len(prompt) if isinstance(prompt, list) else 1
        batch_size = prompt_batch_size * num_images_per_prompt

        if prompt_embeds is None or negative_prompt_embeds is None:
            prompt_embeds, negative_embeds = self.encode_prompt(
                prompt=prompt, negative_prompt=negative_prompt, device=self.text_encoder.device, dtype=dtype,
                return_index=self.return_index,
            )
        else:
            negative_embeds = negative_prompt_embeds

        # Repeat embeddings for num_images_per_prompt
        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        negative_embeds = negative_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        # 3. Initialize latents
        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor

        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(f"Got {len(generator)} generators for {batch_size} samples")

        latents = randn_tensor((batch_size, 16, latent_height, latent_width), generator=generator, device=device, dtype=dtype)
        acc_latents = latents.clone()

        # 4. Calculate alpha if not provided
        if alpha is None:
            image_token_size = latent_height * latent_width
            alpha = 2 * math.sqrt(image_token_size / (64 * 64))

        # 6. Sampling loop
        self.dit_model.eval()

        # Check if guidance is needed
        do_classifier_free_guidance = guidance_scale >= 1.0

        for i in self.progress_bar(range(num_inference_steps, 0, -1)):
            # Calculate timesteps
            t = i / num_inference_steps
            t_next = (i - 1) / num_inference_steps
            # Scale timesteps according to alpha
            t = t * alpha / (1 + (alpha - 1) * t)
            t_next = t_next * alpha / (1 + (alpha - 1) * t_next)
            dt = t - t_next

            # Create tensor with proper device
            t_tensor = torch.tensor([t] * batch_size, device=device, dtype=dtype)

            if do_classifier_free_guidance:
                # Duplicate latents for both conditional and unconditional inputs
                latents_input = torch.cat([latents] * 2)
                # Concatenate negative and positive prompt embeddings
                context_input = torch.cat([negative_embeds, prompt_embeds])
                # Duplicate timesteps for the batch
                t_input = torch.cat([t_tensor] * 2)

                # Get model predictions in a single pass
                model_outputs = self.dit_model(latents_input, context_input, t_input)

                # Split outputs back into unconditional and conditional predictions
                uncond_output, cond_output = model_outputs.chunk(2)

                if apg_config.enabled:
                    # Augmented Parallel Guidance
                    dy = cond_output
                    dd = cond_output - uncond_output
                    # Find parallel direction
                    parallel_direction = (dy * dd).sum() / (dy * dy).sum() * dy
                    orthogonal_direction = dd - parallel_direction
                    # Scale orthogonal component
                    orthogonal_std = orthogonal_direction.std()
                    orthogonal_scale = min(1, apg_config.orthogonal_threshold / orthogonal_std)
                    orthogonal_direction = orthogonal_direction * orthogonal_scale
                    model_output = dy + (guidance_scale - 1) * orthogonal_direction
                else:
                    # Standard classifier-free guidance
                    model_output = uncond_output + guidance_scale * (cond_output - uncond_output)
            else:
                # If no guidance needed, just run the model normally
                model_output = self.dit_model(latents, prompt_embeds, t_tensor)

            # Update latents
            acc_latents = acc_latents + dt * model_output.to(device)
            latents = acc_latents.clone()

        # 7. Decode latents
        # These checks handle the case where mypy doesn't recognize these attributes
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) if hasattr(self.vae, "config") else 0.18215
        shift_factor = getattr(self.vae.config, "shift_factor", 0) if hasattr(self.vae, "config") else 0

        latents = latents / scaling_factor + shift_factor

        vae_dtype = self.vae.dtype if hasattr(self.vae, "dtype") else dtype
        decoded_images = self.vae.decode(latents.to(vae_dtype)).sample if hasattr(self.vae, "decode") else latents

        # Offload all models
        try:
            self.maybe_free_model_hooks()
        except AttributeError as e:
            if "OptimizedModule" in str(e):
                import warnings
                warnings.warn(
                    "Encountered 'OptimizedModule' error when offloading models. "
                    "This issue might be fixed in the future by: "
                    "https://github.com/huggingface/diffusers/pull/10730"
                )
            else:
                raise

        # 8. Post-process images
        images = (decoded_images / 2 + 0.5).clamp(0, 1)
        # Convert to PIL Images
        images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu()
        pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in images]

        return FLitePipelineOutput(
            images=pil_images,
        )
@@ -249,7 +249,8 @@ class DiffusionFeatureExtractor3(nn.Module):
         # lpips_weight=1.0,
         lpips_weight=10.0,
         clip_weight=0.1,
-        pixel_weight=0.1
+        pixel_weight=0.1,
+        model=None
     ):
         dtype = torch.bfloat16
         device = self.vae.device
@@ -274,7 +275,10 @@ class DiffusionFeatureExtractor3(nn.Module):

         # stepped_latents = torch.cat(stepped_chunks, dim=0)

-        stepped_latents = noise - noise_pred
+        if model is not None and hasattr(model, 'get_stepped_pred'):
+            stepped_latents = model.get_stepped_pred(noise_pred, noise)
+        else:
+            stepped_latents = noise - noise_pred

         latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)

@@ -2283,6 +2283,7 @@ class StableDiffusion:
         bleed_latents: torch.FloatTensor = None,
         is_input_scaled=False,
         return_first_prediction=False,
+        bypass_guidance_embedding=False,
         **kwargs,
     ):
         timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
@@ -2299,6 +2300,7 @@ class StableDiffusion:
             add_time_ids=add_time_ids,
             is_input_scaled=is_input_scaled,
             return_conditional_pred=True,
+            bypass_guidance_embedding=bypass_guidance_embedding,
             **kwargs,
         )
         # some schedulers need to run separately, so do that. (euler for example)
@@ -145,7 +145,7 @@ if TYPE_CHECKING:
 def concat_prompt_embeddings(
     unconditional: 'PromptEmbeds',
     conditional: 'PromptEmbeds',
-    n_imgs: int,
+    n_imgs: int = 0,
 ):
     from toolkit.stable_diffusion_model import PromptEmbeds
     text_embeds = torch.cat(