mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 16:59:22 +00:00

8 bit training working on flux
@@ -14,6 +14,7 @@ from PIL import Image
 from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
+from torch import autocast
 from torch.nn import Parameter
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
@@ -54,7 +55,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
 
-from optimum.quanto import freeze, qfloat8, quantize
+from optimum.quanto import freeze, qfloat8, quantize, QTensor
 
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
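
Note on the new QTensor import: optimum.quanto's 8-bit flow is quantize() to swap a module's weights for qfloat8 proxies, then freeze() to materialize them in place; modules quantized this way can also hand back QTensor outputs, which is why QTensor is imported here for an isinstance check later in this commit. A minimal standalone sketch of that flow (the toy model is hypothetical, not from this repo):

    import torch
    from optimum.quanto import freeze, qfloat8, quantize

    model = torch.nn.Sequential(torch.nn.Linear(64, 64))
    quantize(model, weights=qfloat8)  # replace weights with qfloat8 proxies
    freeze(model)                     # materialize quantized weights in place
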
@@ -474,6 +475,23 @@ class StableDiffusion:
             transformer.to(self.device_torch, dtype=dtype)
             flush()
 
+            if self.model_config.lora_path is not None:
+                # need the pipe to do this unfortunately for now
+                # we have to fuse in the weights before quantizing
+                pipe: FluxPipeline = FluxPipeline(
+                    scheduler=scheduler,
+                    text_encoder=None,
+                    tokenizer=None,
+                    text_encoder_2=None,
+                    tokenizer_2=None,
+                    vae=vae,
+                    transformer=transformer,
+                )
+                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+                pipe.fuse_lora()
+                # unfortunately, not an easier way with peft
+                pipe.unload_lora_weights()
+
             if self.model_config.quantize:
                 print("Quantizing transformer")
                 quantize(transformer, weights=qfloat8)
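
The ordering above is the point of this hunk: the LoRA deltas are merged while the transformer is still in full precision, because once the weights have been replaced by frozen qfloat8 proxies there is no straightforward way to add a LoRA on top of them. A sketch of the resulting load order, assuming pipe, lora_path, and the quanto imports as in the diff:

    pipe.load_lora_weights(lora_path, adapter_name="lora1")
    pipe.fuse_lora()                 # merge LoRA deltas into the base weights
    pipe.unload_lora_weights()       # drop the now-redundant adapter modules
    quantize(pipe.transformer, weights=qfloat8)
    freeze(pipe.transformer)
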
@@ -498,7 +516,7 @@ class StableDiffusion:
             text_encoder.to(self.device_torch, dtype=dtype)
 
             print("making pipe")
-            pipe = FluxPipeline(
+            pipe: FluxPipeline = FluxPipeline(
                 scheduler=scheduler,
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
@@ -613,7 +631,7 @@ class StableDiffusion:
         self.unet.eval()
 
         # load any loras we have
-        if self.model_config.lora_path is not None:
+        if self.model_config.lora_path is not None and not self.is_flux:
             pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
             pipe.fuse_lora()
             # unfortunately, not an easier way with peft
@@ -1631,14 +1649,15 @@ class StableDiffusion:
                 width=width_latent,  # 128
             )
 
+            cast_dtype = self.unet.dtype
             # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
             noise_pred = self.unet(
-                hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype),  # [1, 4096, 64]
+                hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype),  # [1, 4096, 64]
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 # todo make sure this doesnt change
                 timestep=timestep / 1000,  # timestep is 1000 scale
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),  # [1, 512, 4096]
-                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),  # [1, 768]
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
+                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
                 txt_ids=text_ids,  # [1, 512, 3]
                 img_ids=latent_image_ids,  # [1, 4096, 3]
                 guidance=guidance,
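
Reading self.unet.dtype at call time, instead of the configured self.torch_dtype, matters once the transformer may be quantized: the module's effective dtype can differ from what the model was configured with, and inputs should match whatever the model actually runs in. A minimal sketch of the idea (shapes and dtype are illustrative stand-ins):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    unet_dtype = torch.bfloat16          # stand-in for self.unet.dtype
    x = torch.randn(1, 4096, 64)         # packed latents, fp32 by default
    x = x.to(device, unet_dtype)         # cast inputs to the model's own dtype
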
@@ -1646,6 +1665,9 @@ class StableDiffusion:
                 **kwargs,
             )[0]
 
+            if isinstance(noise_pred, QTensor):
+                noise_pred = noise_pred.dequantize()
+
             # unpack latents
             noise_pred = self.pipeline._unpack_latents(
                 noise_pred,
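
The new isinstance guard covers the case where the quantized transformer returns an optimum.quanto QTensor rather than a plain torch.Tensor; downstream steps such as _unpack_latents expect a regular tensor. A small helper expressing the same defensive pattern (to_plain is a hypothetical name, not from this repo):

    from optimum.quanto import QTensor

    def to_plain(out):
        # QTensor.dequantize() returns an ordinary torch.Tensor
        return out.dequantize() if isinstance(out, QTensor) else out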