From 60f848a877f8665e49b24aa7af1b1483700f5022 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 23 Feb 2025 12:49:54 -0700 Subject: [PATCH] Send more data when loading the model to the ui --- extensions_built_in/sd_trainer/UITrainer.py | 9 ++++ jobs/process/BaseSDTrainProcess.py | 5 +++ toolkit/stable_diffusion_model.py | 50 +++++++++++++-------- 3 files changed, 45 insertions(+), 19 deletions(-) diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index 7a34337b..2657f60f 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -155,6 +155,15 @@ class UITrainer(SDTrainer): super().hook_before_train_loop() self.maybe_stop() self.update_status("running", "Training") + + def status_update_hook_func(self, string): + self.update_status("running", string) + + def hook_after_sd_init_before_load(self): + super().hook_after_sd_init_before_load() + self.maybe_stop() + self.sd.add_status_update_hook(self.status_update_hook_func) + def sample_step_hook(self, img_num, total_imgs): super().sample_step_hook(img_num, total_imgs) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 150e1e2f..d22dd39d 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -730,6 +730,9 @@ class BaseSDTrainProcess(BaseTrainProcess): def hook_train_loop(self, batch): # return loss return 0.0 + + def hook_after_sd_init_before_load(self): + pass def get_latest_save_path(self, name=None, post=''): if name == None: @@ -1425,6 +1428,8 @@ class BaseSDTrainProcess(BaseTrainProcess): custom_pipeline=self.custom_pipeline, noise_scheduler=sampler, ) + + self.hook_after_sd_init_before_load() # run base sd process run self.sd.load_model() diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 4fe01f9d..40da202d 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -203,6 +203,7 @@ class StableDiffusion: # merge in and preview active with -1 weight self.invert_assistant_lora = False self._after_sample_img_hooks = [] + self._status_update_hooks = [] def load_model(self): if self.is_loaded: @@ -541,10 +542,10 @@ class StableDiffusion: tokenizer = pipe.tokenizer elif self.model_config.is_flux: - print_acc("Loading Flux model") + self.print_and_status_update("Loading Flux model") # base_model_path = "black-forest-labs/FLUX.1-schnell" base_model_path = self.model_config.name_or_path_original - print_acc("Loading transformer") + self.print_and_status_update("Loading transformer") subfolder = 'transformer' transformer_path = model_path local_files_only = False @@ -689,7 +690,7 @@ class StableDiffusion: # patch the state dict method patch_dequantization_on_save(transformer) quantization_type = qfloat8 - print_acc("Quantizing transformer") + self.print_and_status_update("Quantizing transformer") quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) freeze(transformer) transformer.to(self.device_torch) @@ -699,7 +700,7 @@ class StableDiffusion: flush() scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print_acc("Loading vae") + self.print_and_status_update("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() @@ -708,7 +709,7 @@ class StableDiffusion: text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) else: - print_acc("Loading t5") + self.print_and_status_update("Loading t5") tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) @@ -718,19 +719,19 @@ class StableDiffusion: if self.model_config.quantize_te: if self.is_flex2: - print_acc("Quantizing LLM") + self.print_and_status_update("Quantizing LLM") else: - print_acc("Quantizing T5") + self.print_and_status_update("Quantizing T5") quantize(text_encoder_2, weights=qfloat8) freeze(text_encoder_2) flush() - print_acc("Loading clip") + self.print_and_status_update("Loading clip") text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype) - print_acc("making pipe") + self.print_and_status_update("Making pipe") Pipe = FluxPipeline if self.is_flex2: Pipe = Flex2Pipeline @@ -747,7 +748,7 @@ class StableDiffusion: pipe.text_encoder_2 = text_encoder_2 pipe.transformer = transformer - print_acc("preparing") + self.print_and_status_update("Preparing Model") text_encoder = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] @@ -764,10 +765,10 @@ class StableDiffusion: pipe.transformer = pipe.transformer.to(self.device_torch) flush() elif self.model_config.is_lumina2: - print_acc("Loading Lumina2 model") + self.print_and_status_update("Loading Lumina2 model") # base_model_path = "black-forest-labs/FLUX.1-schnell" base_model_path = self.model_config.name_or_path_original - print_acc("Loading transformer") + self.print_and_status_update("Loading transformer") subfolder = 'transformer' transformer_path = model_path if os.path.exists(transformer_path): @@ -803,7 +804,7 @@ class StableDiffusion: # patch the state dict method patch_dequantization_on_save(transformer) quantization_type = qfloat8 - print_acc("Quantizing transformer") + self.print_and_status_update("Quantizing transformer") quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) freeze(transformer) transformer.to(self.device_torch) @@ -813,16 +814,16 @@ class StableDiffusion: flush() scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print_acc("Loading vae") + self.print_and_status_update("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() if self.model_config.te_name_or_path is not None: - print_acc("Loading TE") + self.print_and_status_update("Loading TE") tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype) text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype) else: - print_acc("Loading Gemma2") + self.print_and_status_update("Loading Gemma2") tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) @@ -830,12 +831,12 @@ class StableDiffusion: flush() if self.model_config.quantize_te: - print_acc("Quantizing Gemma2") + self.print_and_status_update("Quantizing Gemma2") quantize(text_encoder, weights=qfloat8) freeze(text_encoder) flush() - print_acc("making pipe") + self.print_and_status_update("Making pipe") pipe: Lumina2Text2ImgPipeline = Lumina2Text2ImgPipeline( scheduler=scheduler, text_encoder=None, @@ -846,7 +847,7 @@ class StableDiffusion: pipe.text_encoder = text_encoder pipe.transformer = transformer - print_acc("preparing") + self.print_and_status_update("Preparing Model") text_encoder = pipe.text_encoder tokenizer = pipe.tokenizer @@ -1041,6 +1042,17 @@ class StableDiffusion: def add_after_sample_image_hook(self, func): self._after_sample_img_hooks.append(func) + + def _status_update(self, status: str): + for hook in self._status_update_hooks: + hook(status) + + def print_and_status_update(self, status: str): + print_acc(status) + self._status_update(status) + + def add_status_update_hook(self, func): + self._status_update_hooks.append(func) @torch.no_grad() def generate_images(