8-bit training working on Flux

This commit is contained in:
Jaret Burkett
2024-08-06 11:53:27 -06:00
parent 272c8608c2
commit c2424087d6
7 changed files with 82 additions and 31 deletions

View File

@@ -14,6 +14,7 @@ from PIL import Image
from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file
from torch import autocast
from torch.nn import Parameter
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
@@ -54,7 +55,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
from optimum.quanto import freeze, qfloat8, quantize
from optimum.quanto import freeze, qfloat8, quantize, QTensor
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -474,6 +475,23 @@ class StableDiffusion:
transformer.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.lora_path is not None:
# need the pipe to do this unfortunately for now
# we have to fuse in the weights before quantizing
pipe: FluxPipeline = FluxPipeline(
scheduler=scheduler,
text_encoder=None,
tokenizer=None,
text_encoder_2=None,
tokenizer_2=None,
vae=vae,
transformer=transformer,
)
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
pipe.fuse_lora()
# unfortunately, there is no easier way to do this with peft
pipe.unload_lora_weights()
if self.model_config.quantize:
print("Quantizing transformer")
quantize(transformer, weights=qfloat8)
@@ -498,7 +516,7 @@ class StableDiffusion:
text_encoder.to(self.device_torch, dtype=dtype)
print("making pipe")
pipe = FluxPipeline(
pipe: FluxPipeline = FluxPipeline(
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
@@ -613,7 +631,7 @@ class StableDiffusion:
self.unet.eval()
# load any loras we have
if self.model_config.lora_path is not None:
if self.model_config.lora_path is not None and not self.is_flux:
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
pipe.fuse_lora()
# unfortunately, there is no easier way to do this with peft
@@ -1631,14 +1649,15 @@ class StableDiffusion:
width=width_latent, # 128
)
cast_dtype = self.unet.dtype
# with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
noise_pred = self.unet(
hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64]
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
# TODO: make sure this doesn't change
timestep=timestep / 1000, # timestep is 1000 scale
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), # [1, 512, 4096]
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), # [1, 768]
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), # [1, 512, 4096]
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768]
txt_ids=text_ids, # [1, 512, 3]
img_ids=latent_image_ids, # [1, 4096, 3]
guidance=guidance,
@@ -1646,6 +1665,9 @@ class StableDiffusion:
**kwargs,
)[0]
if isinstance(noise_pred, QTensor):
noise_pred = noise_pred.dequantize()
# unpack latents
noise_pred = self.pipeline._unpack_latents(
noise_pred,