Some work on SD3 training. It is not working yet.

This commit is contained in:
Jaret Burkett
2024-06-13 12:19:16 -06:00
parent cb5d28cba9
commit bd10d2d668
12 changed files with 306 additions and 36 deletions

View File

@@ -40,13 +40,13 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
import diffusers
from diffusers import \
AutoencoderKL, \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
from transformers import T5EncoderModel
from transformers import T5EncoderModel, BitsAndBytesConfig
from toolkit.util.pixart_sigma_patch import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
@@ -147,6 +147,7 @@ class StableDiffusion:
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
self.is_v3 = model_config.is_v3
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
@@ -236,6 +237,64 @@ class StableDiffusion:
te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
flush()
print("Injecting alt weights")
elif self.model_config.is_v3:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = StableDiffusion3Pipeline
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_id = "stabilityai/stable-diffusion-3-medium"
text_encoder3 = T5EncoderModel.from_pretrained(
model_id,
subfolder="text_encoder_3",
# quantization_config=quantization_config,
revision="refs/pr/26",
device_map="cuda"
)
# see if path exists
if not os.path.exists(model_path) or os.path.isdir(model_path):
try:
# try to load with default diffusers
pipe = pipln.from_pretrained(
model_path,
dtype=dtype,
device=self.device_torch,
text_encoder_3=text_encoder3,
# variant="fp16",
use_safetensors=True,
revision="refs/pr/26",
repo_type="model",
ignore_patterns=["*.md", "*..gitattributes"],
**load_args
)
except Exception as e:
print(f"Error loading from pretrained: {e}")
raise e
else:
pipe = pipln.from_single_file(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
text_encoder_3=text_encoder3,
)
flush()
text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
# replace the to function with a no-op since it throws an error instead of a warning
# text_encoders[2].to = lambda *args, **kwargs: None
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
elif self.model_config.is_pixart:
te_kwargs = {}
# handle quantization of TE
@@ -361,8 +420,8 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
if self.is_pixart:
# pixart doesnt use a unet
if self.is_pixart or self.is_v3:
# pixart and sd3 dont use a unet
self.unet = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
@@ -487,6 +546,8 @@ class StableDiffusion:
Pipe = StableDiffusionKDiffusionXLPipeline
elif self.is_xl:
Pipe = StableDiffusionXLPipeline
elif self.is_v3:
Pipe = StableDiffusion3Pipeline
else:
Pipe = StableDiffusionPipeline
@@ -515,15 +576,30 @@ class StableDiffusion:
if self.is_xl:
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
transformer=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
text_encoder_3=self.text_encoder[2],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
tokenizer_3=self.tokenizer[2],
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
pipeline.watermark = None
elif self.is_v3:
pipeline = Pipe(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
text_encoder_3=self.text_encoder[2],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
tokenizer_3=self.tokenizer[2],
scheduler=noise_scheduler,
**extra_args
)
elif self.is_pixart:
pipeline = PixArtAlphaPipeline(
vae=self.vae,
@@ -576,7 +652,7 @@ class StableDiffusion:
if self.network is not None:
start_multiplier = self.network.multiplier
pipeline.to(self.device_torch)
# pipeline.to(self.device_torch)
with network:
with torch.no_grad():
@@ -744,6 +820,19 @@ class StableDiffusion:
latents=gen_config.latents,
**extra
).images[0]
elif self.is_v3:
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
**extra
).images[0]
elif self.is_pixart:
# needs attention masks for some reason
img = pipeline(
@@ -1004,6 +1093,20 @@ class StableDiffusion:
)
return torch.cat(out_chunks, dim=0)
def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
    """Precondition SD3 model outputs one sample at a time.

    For each sample along dim 0, the sigma for that sample's timestep is
    requested from ``self.noise_scheduler.get_sigmas`` (``self`` comes from
    the enclosing scope) and the sample is mapped to
    ``model_output * (-sigma) + model_input``.

    Args:
        model_output: Batched raw model prediction; dim 0 is the batch axis.
        model_input: Batched tensor that was fed to the model; it is chunked
            along dim 0 the same way as ``model_output``.
        timestep_tensor: Timesteps for the batch. May hold one entry per
            sample, a single entry, or be a zero-dim scalar; the last two
            forms are shared across every sample in the batch.

    Returns:
        The per-sample preconditioned outputs concatenated along dim 0.
    """
    batch_size = model_output.shape[0]

    # A zero-dim timestep has no batch axis to chunk along, so give it one.
    if timestep_tensor.ndim == 0:
        timestep_tensor = timestep_tensor.unsqueeze(0)
    # A lone timestep applies to every sample in the batch.
    if timestep_tensor.shape[0] == 1 and batch_size > 1:
        timestep_tensor = timestep_tensor.expand(batch_size, *timestep_tensor.shape[1:])

    mo_chunks = torch.chunk(model_output, batch_size, dim=0)
    mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
    timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)

    out_chunks = []
    for idx, mo_chunk in enumerate(mo_chunks):
        # Sigma is looked up per sample because each one can sit at a
        # different timestep. NOTE(review): get_sigmas is assumed to return a
        # tensor broadcastable against a single-sample chunk - confirm
        # against the scheduler implementation.
        sigmas = self.noise_scheduler.get_sigmas(
            timestep_chunks[idx],
            n_dim=model_output.ndim,
            dtype=model_output.dtype,
            device=model_output.device,
        )
        # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
        # Preconditioning of the model outputs.
        out_chunks.append(mo_chunk * (-sigmas) + mi_chunks[idx])
    return torch.cat(out_chunks, dim=0)
if self.is_xl:
with torch.no_grad():
# 16, 6 for bs of 4
@@ -1177,12 +1280,22 @@ class StableDiffusion:
self.unet.to(self.device_torch)
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
if self.is_v3:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
else:
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
conditional_pred = noise_pred
@@ -1343,6 +1456,19 @@ class StableDiffusion:
dropout_prob=dropout_prob,
)
)
if self.is_v3:
return PromptEmbeds(
train_tools.encode_prompts_sd3(
self.tokenizer,
self.text_encoder,
prompt,
num_images_per_prompt=num_images_per_prompt,
truncate=not long_prompts,
max_length=max_length,
dropout_prob=dropout_prob,
pipeline=self.pipeline,
)
)
elif self.is_pixart:
embeds, attention_mask = train_tools.encode_prompts_pixart(
self.tokenizer,
@@ -1735,7 +1861,7 @@ class StableDiffusion:
def save_device_state(self):
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
if self.is_pixart:
if self.is_pixart or self.is_v3:
unet_has_grad = self.unet.proj_out.weight.requires_grad
else:
unet_has_grad = self.unet.conv_in.weight.requires_grad
@@ -1755,11 +1881,15 @@ class StableDiffusion:
if isinstance(self.text_encoder, list):
self.device_state['text_encoder']: List[dict] = []
for encoder in self.text_encoder:
try:
te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
except:
te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
self.device_state['text_encoder'].append({
'training': encoder.training,
'device': encoder.device,
# todo there has to be a better way to do this
'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad
'requires_grad': te_has_grad
})
else:
if isinstance(self.text_encoder, T5EncoderModel):