mirror of https://github.com/ostris/ai-toolkit.git
Some work on sd3 training. Not working
@@ -29,6 +29,7 @@ import gc
 import torch
 from jobs.process import BaseSDTrainProcess
 from torchvision import transforms
+import math
@@ -366,7 +367,29 @@ class SDTrainer(BaseSDTrainProcess):
             loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
             loss = loss_per_element
         else:
-            if self.train_config.loss_type == "mae":
+            # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
+            if self.sd.is_v3:
+                target = noisy_latents.detach()
+                bsz = pred.shape[0]
+                # todo implement others
+                # weighing_scheme =
+                # 3 just do mode for now?
+                # if args.weighting_scheme == "sigma_sqrt":
+                sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
+                weighting = (sigmas ** -2.0).float()
+                # elif args.weighting_scheme == "logit_normal":
+                #     # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+                #     u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
+                #     weighting = torch.nn.functional.sigmoid(u)
+                # elif args.weighting_scheme == "mode":
+                #     mode_scale = 1.29
+                #     # See sec 3.1 in the SD3 paper (20).
+                #     u = torch.rand(size=(bsz,), device=pred.device)
+                #     weighting = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+
+                loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)
+
+            elif self.train_config.loss_type == "mae":
                 loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
             else:
                 loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
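The SD3 branch above computes a flow-matching loss with the diffusers "sigma_sqrt" weighting (weighting = sigmas ** -2.0). A minimal standalone sketch of that computation, assuming a batch of 4-D latents (the function name and tensors here are illustrative, not part of the commit):

import torch

def sigma_sqrt_weighted_loss(pred: torch.Tensor, target: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # sigmas is already broadcast to pred.ndim, e.g. shape (bsz, 1, 1, 1)
    weighting = (sigmas ** -2.0).float()
    loss = weighting * (pred.float() - target.float()) ** 2
    # flatten per sample as the commit does, then reduce
    return loss.reshape(pred.shape[0], -1).mean()

pred = torch.randn(4, 16, 64, 64)    # SD3 latents have 16 channels
target = torch.randn(4, 16, 64, 64)
sigmas = torch.rand(4, 1, 1, 1).clamp_min(1e-3)
print(sigma_sqrt_weighted_loss(pred, target, sigmas))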
@@ -1244,6 +1244,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             conv_alpha=self.network_config.conv_alpha,
             is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
             is_v2=self.model_config.is_v2,
+            is_v3=self.model_config.is_v3,
             is_ssd=self.model_config.is_ssd,
             is_vega=self.model_config.is_vega,
             dropout=self.network_config.dropout,
@@ -1,7 +1,7 @@
 torch
 torchvision
 safetensors
-diffusers==0.26.3
+diffusers
 transformers
 lycoris-lora==1.8.3
 flatten_json
@@ -336,6 +336,7 @@ class ModelConfig:
         self.is_xl: bool = kwargs.get('is_xl', False)
         self.is_pixart: bool = kwargs.get('is_pixart', False)
         self.is_pixart_sigma: bool = kwargs.get('is_pixart', False)
+        self.is_v3: bool = kwargs.get('is_v3', False)
         if self.is_pixart_sigma:
             self.is_pixart = True
         self.is_ssd: bool = kwargs.get('is_ssd', False)
@@ -1353,6 +1353,8 @@ class LatentCachingMixin:
             file_item.latent_space_version = self.sd.model_config.latent_space_version
         elif self.sd.is_xl:
             file_item.latent_space_version = 'sdxl'
+        elif self.sd.is_v3:
+            file_item.latent_space_version = 'sd3'
         else:
             file_item.latent_space_version = 'sd1'
         file_item.is_caching_to_disk = to_disk
@@ -152,6 +152,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             train_unet: Optional[bool] = True,
             is_sdxl=False,
             is_v2=False,
+            is_v3=False,
             is_pixart: bool = False,
             use_bias: bool = False,
             is_lorm: bool = False,
@@ -200,6 +201,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.multiplier = multiplier
         self.is_sdxl = is_sdxl
         self.is_v2 = is_v2
+        self.is_v3 = is_v3
         self.is_pixart = is_pixart
         self.network_type = network_type
         if self.network_type.lower() == "dora":
@@ -233,7 +235,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 target_replace_modules: List[torch.nn.Module],
         ) -> List[LoRAModule]:
             unet_prefix = self.LORA_PREFIX_UNET
-            if is_pixart:
+            if is_pixart or is_v3:
                 unet_prefix = f"lora_transformer"

             prefix = (
@@ -346,6 +348,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
             target_modules += target_conv_modules

+        if is_v3:
+            target_modules = ["SD3Transformer2DModel"]
+
         if train_unet:
             self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
         else:
@@ -48,7 +48,7 @@ class LoRAGenerator(torch.nn.Module):
             head_size: int = 512,
             num_mlp_layers: int = 1,
             output_size: int = 768,
-            dropout: float = 0.5
+            dropout: float = 0.0
     ):
         super().__init__()
         self.input_size = input_size
@@ -131,8 +131,12 @@ class InstantLoRAMidModule(torch.nn.Module):
             x_chunk = x_chunks[i]
             # reshape
             weight_chunk = weight_chunk.view(self.down_shape)
-            # run a simple lenear layer with the down weight
-            x_chunk = x_chunk @ weight_chunk.T
+            # check if is conv or linear
+            if len(weight_chunk.shape) == 4:
+                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
+            else:
+                # run a simple linear layer with the down weight
+                x_chunk = x_chunk @ weight_chunk.T
             x_out.append(x_chunk)
         x = torch.cat(x_out, dim=0)
         return x
@@ -158,8 +162,12 @@ class InstantLoRAMidModule(torch.nn.Module):
             x_chunk = x_chunks[i]
             # reshape
             weight_chunk = weight_chunk.view(self.up_shape)
-            # run a simple lenear layer with the down weight
-            x_chunk = x_chunk @ weight_chunk.T
+            # check if is conv or linear
+            if len(weight_chunk.shape) == 4:
+                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
+            else:
+                # run a simple linear layer with the down weight
+                x_chunk = x_chunk @ weight_chunk.T
             x_out.append(x_chunk)
         x = torch.cat(x_out, dim=0)
         return x
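Both replacement blocks above dispatch on the weight chunk's rank: a 4-D tensor is applied as a conv2d kernel, anything else as a linear weight. The same dispatch in isolation (function name and shapes are illustrative, not from the commit):

import torch
import torch.nn as nn

def apply_weight_chunk(x: torch.Tensor, weight_chunk: torch.Tensor) -> torch.Tensor:
    if len(weight_chunk.shape) == 4:
        # conv weight (out_ch, in_ch, kh, kw) applied to x of shape (n, in_ch, h, w)
        return nn.functional.conv2d(x, weight_chunk)
    # linear weight (out_features, in_features) applied to x of shape (n, in_features)
    return x @ weight_chunk.T

print(apply_weight_chunk(torch.randn(2, 8), torch.randn(4, 8)).shape)                # (2, 4)
print(apply_weight_chunk(torch.randn(2, 8, 16, 16), torch.randn(4, 8, 1, 1)).shape)  # (2, 4, 16, 16)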
@@ -4,7 +4,6 @@ import torch
 import sys

 from PIL import Image
-from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -13,9 +13,12 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
-    LCMScheduler
+    LCMScheduler,
+    FlowMatchEulerDiscreteScheduler,
 )

+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+
 from k_diffusion.external import CompVisDenoiser

 from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
@@ -112,6 +115,15 @@ def get_sampler(
         scheduler_cls = LCMScheduler
     elif sampler == "custom_lcm":
         scheduler_cls = CustomLCMScheduler
+    elif sampler == "flowmatch":
+        scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
+        config_to_use = {
+            "_class_name": "FlowMatchEulerDiscreteScheduler",
+            "_diffusers_version": "0.29.0.dev0",
+            "num_train_timesteps": 1000,
+            "shift": 3.0
+        }

     config = copy.deepcopy(config_to_use)
     config.update(sched_init_args)
toolkit/samplers/custom_flowmatch_sampler.py (new file)
@@ -0,0 +1,32 @@
+from typing import Union
+
+from diffusers import FlowMatchEulerDiscreteScheduler
+import torch
+
+class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+
+    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
+        sigmas = self.sigmas.to(device=device, dtype=dtype)
+        schedule_timesteps = self.timesteps.to(device)
+        timesteps = timesteps.to(device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+
+        return sigma
+
+    def add_noise(
+        self,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
+        n_dim = original_samples.ndim
+        sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
+        noisy_model_input = sigmas * noise + (1.0 - sigmas) * original_samples
+        return noisy_model_input
+
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+        return sample
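The subclass leaves inference untouched and only adds a training-side add_noise: the linear flow-matching interpolation x_t = sigma * noise + (1 - sigma) * x_0. A usage sketch, assuming the diffusers FlowMatchEulerDiscreteScheduler populates timesteps and sigmas at construction (as it does in 0.29-era releases):

import torch
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

scheduler = CustomFlowMatchEulerDiscreteScheduler.from_config({
    "num_train_timesteps": 1000,
    "shift": 3.0,
})

latents = torch.randn(2, 16, 64, 64)
noise = torch.randn_like(latents)
# draw training timesteps from the scheduler's own schedule so get_sigmas can match them exactly
idx = torch.randint(0, scheduler.timesteps.shape[0], (2,))
timesteps = scheduler.timesteps[idx]

noisy = scheduler.add_noise(latents, noise, timesteps)  # sigma * noise + (1 - sigma) * latents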
@@ -40,13 +40,13 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
+    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
 import diffusers
 from diffusers import \
     AutoencoderKL, \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
-from transformers import T5EncoderModel
+from transformers import T5EncoderModel, BitsAndBytesConfig
 from toolkit.util.pixart_sigma_patch import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline

 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
@@ -147,6 +147,7 @@ class StableDiffusion:
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
         self.is_ssd = model_config.is_ssd
+        self.is_v3 = model_config.is_v3
         self.is_vega = model_config.is_vega
         self.is_pixart = model_config.is_pixart
@@ -236,6 +237,64 @@ class StableDiffusion:
             te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
             flush()
             print("Injecting alt weights")
+        elif self.model_config.is_v3:
+            if self.custom_pipeline is not None:
+                pipln = self.custom_pipeline
+            else:
+                pipln = StableDiffusion3Pipeline
+
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+            model_id = "stabilityai/stable-diffusion-3-medium"
+            text_encoder3 = T5EncoderModel.from_pretrained(
+                model_id,
+                subfolder="text_encoder_3",
+                # quantization_config=quantization_config,
+                revision="refs/pr/26",
+                device_map="cuda"
+            )
+
+            # see if path exists
+            if not os.path.exists(model_path) or os.path.isdir(model_path):
+                try:
+                    # try to load with default diffusers
+                    pipe = pipln.from_pretrained(
+                        model_path,
+                        dtype=dtype,
+                        device=self.device_torch,
+                        text_encoder_3=text_encoder3,
+                        # variant="fp16",
+                        use_safetensors=True,
+                        revision="refs/pr/26",
+                        repo_type="model",
+                        ignore_patterns=["*.md", "*..gitattributes"],
+                        **load_args
+                    )
+                except Exception as e:
+                    print(f"Error loading from pretrained: {e}")
+                    raise e
+
+            else:
+                pipe = pipln.from_single_file(
+                    model_path,
+                    device=self.device_torch,
+                    torch_dtype=self.torch_dtype,
+                    text_encoder_3=text_encoder3,
+                )
+
+            flush()
+
+            text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
+            tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
+            # replace the to function with a no-op since it throws an error instead of a warning
+            # text_encoders[2].to = lambda *args, **kwargs: None
+            for text_encoder in text_encoders:
+                text_encoder.to(self.device_torch, dtype=dtype)
+                text_encoder.requires_grad_(False)
+                text_encoder.eval()
+            text_encoder = text_encoders
+
+
         elif self.model_config.is_pixart:
             te_kwargs = {}
             # handle quantization of TE
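The commented-out quantization_config above shows the intent to load the large T5-XXL third text encoder in 8-bit to save VRAM. A sketch of what enabling it would look like (requires the bitsandbytes and accelerate packages; the commit itself leaves the flag disabled and loads the encoder in full precision):

from transformers import T5EncoderModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder3 = T5EncoderModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium",
    subfolder="text_encoder_3",
    quantization_config=quantization_config,  # 8-bit weights via bitsandbytes
    device_map="auto",
)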
@@ -361,8 +420,8 @@ class StableDiffusion:
         # add hacks to unet to help training
         # pipe.unet = prepare_unet_for_training(pipe.unet)

-        if self.is_pixart:
-            # pixart doesnt use a unet
+        if self.is_pixart or self.is_v3:
+            # pixart and sd3 dont use a unet
             self.unet = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
@@ -487,6 +546,8 @@ class StableDiffusion:
             Pipe = StableDiffusionKDiffusionXLPipeline
         elif self.is_xl:
             Pipe = StableDiffusionXLPipeline
+        elif self.is_v3:
+            Pipe = StableDiffusion3Pipeline
         else:
             Pipe = StableDiffusionPipeline
@@ -515,15 +576,30 @@ class StableDiffusion:
         if self.is_xl:
             pipeline = Pipe(
                 vae=self.vae,
                 unet=self.unet,
                 text_encoder=self.text_encoder[0],
                 text_encoder_2=self.text_encoder[1],
                 tokenizer=self.tokenizer[0],
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
                 **extra_args
             ).to(self.device_torch)
             pipeline.watermark = None
+        elif self.is_v3:
+            pipeline = Pipe(
+                vae=self.vae,
+                transformer=self.unet,
+                text_encoder=self.text_encoder[0],
+                text_encoder_2=self.text_encoder[1],
+                text_encoder_3=self.text_encoder[2],
+                tokenizer=self.tokenizer[0],
+                tokenizer_2=self.tokenizer[1],
+                tokenizer_3=self.tokenizer[2],
+                scheduler=noise_scheduler,
+                **extra_args
+            )
         elif self.is_pixart:
             pipeline = PixArtAlphaPipeline(
                 vae=self.vae,
@@ -576,7 +652,7 @@ class StableDiffusion:
         if self.network is not None:
             start_multiplier = self.network.multiplier

-        pipeline.to(self.device_torch)
+        # pipeline.to(self.device_torch)

         with network:
             with torch.no_grad():
@@ -744,6 +820,19 @@ class StableDiffusion:
                     latents=gen_config.latents,
                     **extra
                 ).images[0]
+            elif self.is_v3:
+                img = pipeline(
+                    prompt_embeds=conditional_embeds.text_embeds,
+                    pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+                    negative_prompt_embeds=unconditional_embeds.text_embeds,
+                    negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+                    height=gen_config.height,
+                    width=gen_config.width,
+                    num_inference_steps=gen_config.num_inference_steps,
+                    guidance_scale=gen_config.guidance_scale,
+                    latents=gen_config.latents,
+                    **extra
+                ).images[0]
             elif self.is_pixart:
                 # needs attention masks for some reason
                 img = pipeline(
@@ -1004,6 +1093,20 @@ class StableDiffusion:
             )
             return torch.cat(out_chunks, dim=0)

+        def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
+            mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
+            mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+            timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+            out_chunks = []
+            # unsqueeze if timestep is zero dim
+            for idx in range(model_output.shape[0]):
+                sigmas = self.noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, dtype=model_output.dtype, device=model_output.device)
+                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                # Preconditioning of the model outputs.
+                out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
+                out_chunks.append(out)
+            return torch.cat(out_chunks, dim=0)
+
         if self.is_xl:
             with torch.no_grad():
                 # 16, 6 for bs of 4
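The loop converts the transformer's velocity prediction into a denoised-sample estimate: with x_t = sigma * noise + (1 - sigma) * x_0 and the model predicting v = noise - x_0, the reconstruction is x_0 = x_t - sigma * v, which is exactly out = model_output * (-sigmas) + model_input. Since get_sigmas already broadcasts sigma to the output's shape, the chunked loop could also be written vectorized; a sketch (the function name is hypothetical):

import torch

def precondition_outputs_vectorized(model_output: torch.Tensor, model_input: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # sigmas: shape (bsz, 1, 1, 1), broadcast against (bsz, c, h, w)
    # denoised estimate: x0 = x_t - sigma * v
    return model_input - sigmas * model_output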
@@ -1177,12 +1280,22 @@ class StableDiffusion:
         self.unet.to(self.device_torch)
         if self.unet.dtype != self.torch_dtype:
             self.unet = self.unet.to(dtype=self.torch_dtype)
-        noise_pred = self.unet(
-            latent_model_input.to(self.device_torch, self.torch_dtype),
-            timestep,
-            encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
-            **kwargs,
-        ).sample
+        if self.is_v3:
+            noise_pred = self.unet(
+                hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
+                timestep=timestep,
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
+                **kwargs,
+            ).sample
+            noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
+        else:
+            noise_pred = self.unet(
+                latent_model_input.to(self.device_torch, self.torch_dtype),
+                timestep=timestep,
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+                **kwargs,
+            ).sample

         conditional_pred = noise_pred
@@ -1343,6 +1456,19 @@ class StableDiffusion:
                     dropout_prob=dropout_prob,
                 )
             )
+        if self.is_v3:
+            return PromptEmbeds(
+                train_tools.encode_prompts_sd3(
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    truncate=not long_prompts,
+                    max_length=max_length,
+                    dropout_prob=dropout_prob,
+                    pipeline=self.pipeline,
+                )
+            )
         elif self.is_pixart:
             embeds, attention_mask = train_tools.encode_prompts_pixart(
                 self.tokenizer,
@@ -1735,7 +1861,7 @@ class StableDiffusion:
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_pixart:
+        if self.is_pixart or self.is_v3:
             unet_has_grad = self.unet.proj_out.weight.requires_grad
         else:
             unet_has_grad = self.unet.conv_in.weight.requires_grad
@@ -1755,11 +1881,15 @@ class StableDiffusion:
         if isinstance(self.text_encoder, list):
             self.device_state['text_encoder']: List[dict] = []
             for encoder in self.text_encoder:
+                try:
+                    te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
+                except:
+                    te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
                 self.device_state['text_encoder'].append({
                     'training': encoder.training,
                     'device': encoder.device,
                     # todo there has to be a better way to do this
-                    'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad
+                    'requires_grad': te_has_grad
                 })
         else:
             if isinstance(self.text_encoder, T5EncoderModel):
@@ -25,6 +25,7 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
+    StableDiffusion3Pipeline
 )
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import torch
@@ -580,6 +581,58 @@ def encode_prompts_xl(

     return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds

+
+def encode_prompts_sd3(
+        tokenizers: list['CLIPTokenizer'],
+        text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]],
+        prompts: list[str],
+        num_images_per_prompt: int = 1,
+        truncate: bool = True,
+        max_length=None,
+        dropout_prob=0.0,
+        pipeline: StableDiffusion3Pipeline = None,
+):
+    text_embeds_list = []
+    pooled_text_embeds = None  # always text_encoder_2's pool
+
+    prompt_2 = prompts
+    prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+    prompt_3 = prompts
+    prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+    device = text_encoders[0].device
+
+    prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds(
+        prompt=prompts,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        clip_skip=None,
+        clip_model_index=0,
+    )
+    prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds(
+        prompt=prompt_2,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        clip_skip=None,
+        clip_model_index=1,
+    )
+    clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+    t5_prompt_embed = pipeline._get_t5_prompt_embeds(
+        prompt=prompt_3,
+        num_images_per_prompt=num_images_per_prompt,
+        device=device
+    )
+
+    clip_prompt_embeds = torch.nn.functional.pad(
+        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+    )
+
+    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+    return prompt_embeds, pooled_prompt_embeds
+
+
 # ref for long prompts https://github.com/huggingface/diffusers/issues/2136
 def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
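encode_prompts_sd3 mirrors the SD3 pipeline's embedding layout: the two CLIP hidden states are concatenated on the feature axis, zero-padded up to the T5 feature width, then stacked with the T5 hidden states along the sequence axis. The shape arithmetic in isolation, using SD3's published sizes (batch size illustrative):

import torch

clip_l = torch.randn(1, 77, 768)    # CLIP-L hidden states
clip_g = torch.randn(1, 77, 1280)   # OpenCLIP-bigG hidden states
t5 = torch.randn(1, 256, 4096)      # T5-XXL hidden states

clip = torch.cat([clip_l, clip_g], dim=-1)                                # (1, 77, 2048)
clip = torch.nn.functional.pad(clip, (0, t5.shape[-1] - clip.shape[-1]))  # (1, 77, 4096)
prompt_embeds = torch.cat([clip, t5], dim=-2)                             # (1, 333, 4096)
print(prompt_embeds.shape)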
@@ -720,18 +773,22 @@ def concat_embeddings(


 def add_all_snr_to_noise_scheduler(noise_scheduler, device):
-    if hasattr(noise_scheduler, "all_snr"):
-        return
-    # compute it
-    with torch.no_grad():
-        alphas_cumprod = noise_scheduler.alphas_cumprod
-        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
-        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
-        alpha = sqrt_alphas_cumprod
-        sigma = sqrt_one_minus_alphas_cumprod
-        all_snr = (alpha / sigma) ** 2
-        all_snr.requires_grad = False
-        noise_scheduler.all_snr = all_snr.to(device)
+    try:
+        if hasattr(noise_scheduler, "all_snr"):
+            return
+        # compute it
+        with torch.no_grad():
+            alphas_cumprod = noise_scheduler.alphas_cumprod
+            sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+            sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+            alpha = sqrt_alphas_cumprod
+            sigma = sqrt_one_minus_alphas_cumprod
+            all_snr = (alpha / sigma) ** 2
+            all_snr.requires_grad = False
+            noise_scheduler.all_snr = all_snr.to(device)
+    except Exception as e:
+        print(e)
+        print("Failed to add all_snr to noise_scheduler")


 def get_all_snr(noise_scheduler, device):
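The new try/except guards schedulers that expose no alphas_cumprod, such as the flow-matching scheduler registered above. For DDPM-style schedulers the cached quantity is the usual signal-to-noise ratio, (alpha / sigma) ** 2 = alphas_cumprod / (1 - alphas_cumprod); a quick standalone check of that identity:

import torch
from diffusers import DDPMScheduler

ac = DDPMScheduler().alphas_cumprod
snr = (torch.sqrt(ac) / torch.sqrt(1.0 - ac)) ** 2
assert torch.allclose(snr, ac / (1.0 - ac))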