Send more data to the UI when loading the model

This commit is contained in:
Jaret Burkett
2025-02-23 12:49:54 -07:00
parent b366e46f1c
commit 60f848a877
3 changed files with 45 additions and 19 deletions

View File

@@ -203,6 +203,7 @@ class StableDiffusion:
# merge in and preview active with -1 weight
self.invert_assistant_lora = False
self._after_sample_img_hooks = []
self._status_update_hooks = []
def load_model(self):
if self.is_loaded:
@@ -541,10 +542,10 @@ class StableDiffusion:
tokenizer = pipe.tokenizer
elif self.model_config.is_flux:
print_acc("Loading Flux model")
self.print_and_status_update("Loading Flux model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
print_acc("Loading transformer")
self.print_and_status_update("Loading transformer")
subfolder = 'transformer'
transformer_path = model_path
local_files_only = False
@@ -689,7 +690,7 @@ class StableDiffusion:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
print_acc("Quantizing transformer")
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
@@ -699,7 +700,7 @@ class StableDiffusion:
flush()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print_acc("Loading vae")
self.print_and_status_update("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
@@ -708,7 +709,7 @@ class StableDiffusion:
text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
else:
print_acc("Loading t5")
self.print_and_status_update("Loading t5")
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
torch_dtype=dtype)
@@ -718,19 +719,19 @@ class StableDiffusion:
if self.model_config.quantize_te:
if self.is_flex2:
print_acc("Quantizing LLM")
self.print_and_status_update("Quantizing LLM")
else:
print_acc("Quantizing T5")
self.print_and_status_update("Quantizing T5")
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
flush()
print_acc("Loading clip")
self.print_and_status_update("Loading clip")
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
print_acc("making pipe")
self.print_and_status_update("Making pipe")
Pipe = FluxPipeline
if self.is_flex2:
Pipe = Flex2Pipeline
@@ -747,7 +748,7 @@ class StableDiffusion:
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
print_acc("preparing")
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
@@ -764,10 +765,10 @@ class StableDiffusion:
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
elif self.model_config.is_lumina2:
print_acc("Loading Lumina2 model")
self.print_and_status_update("Loading Lumina2 model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
print_acc("Loading transformer")
self.print_and_status_update("Loading transformer")
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
@@ -803,7 +804,7 @@ class StableDiffusion:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
print_acc("Quantizing transformer")
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
@@ -813,16 +814,16 @@ class StableDiffusion:
flush()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print_acc("Loading vae")
self.print_and_status_update("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
if self.model_config.te_name_or_path is not None:
print_acc("Loading TE")
self.print_and_status_update("Loading TE")
tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
else:
print_acc("Loading Gemma2")
self.print_and_status_update("Loading Gemma2")
tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
@@ -830,12 +831,12 @@ class StableDiffusion:
flush()
if self.model_config.quantize_te:
print_acc("Quantizing Gemma2")
self.print_and_status_update("Quantizing Gemma2")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)
flush()
print_acc("making pipe")
self.print_and_status_update("Making pipe")
pipe: Lumina2Text2ImgPipeline = Lumina2Text2ImgPipeline(
scheduler=scheduler,
text_encoder=None,
@@ -846,7 +847,7 @@ class StableDiffusion:
pipe.text_encoder = text_encoder
pipe.transformer = transformer
print_acc("preparing")
self.print_and_status_update("Preparing Model")
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer
@@ -1041,6 +1042,17 @@ class StableDiffusion:
def add_after_sample_image_hook(self, func):
    """Register *func* to be invoked after a sample image is produced."""
    hooks = self._after_sample_img_hooks
    hooks.append(func)
def _status_update(self, status: str):
for hook in self._status_update_hooks:
hook(status)
def print_and_status_update(self, status: str):
    """Print *status* to the console and forward it to all status hooks."""
    # Console output first, then the UI-facing hooks.
    print_acc(status)
    self._status_update(status)
def add_status_update_hook(self, func):
    """Register *func* to receive status strings during model loading."""
    registry = self._status_update_hooks
    registry.append(func)
@torch.no_grad()
def generate_images(