Send more data to the UI when loading the model

@@ -203,6 +203,7 @@ class StableDiffusion:
         # merge in and preview active with -1 weight
         self.invert_assistant_lora = False
         self._after_sample_img_hooks = []
+        self._status_update_hooks = []
 
     def load_model(self):
         if self.is_loaded:
@@ -541,10 +542,10 @@ class StableDiffusion:
             tokenizer = pipe.tokenizer
 
         elif self.model_config.is_flux:
-            print_acc("Loading Flux model")
+            self.print_and_status_update("Loading Flux model")
             # base_model_path = "black-forest-labs/FLUX.1-schnell"
             base_model_path = self.model_config.name_or_path_original
-            print_acc("Loading transformer")
+            self.print_and_status_update("Loading transformer")
             subfolder = 'transformer'
             transformer_path = model_path
             local_files_only = False
@@ -689,7 +690,7 @@ class StableDiffusion:
                 # patch the state dict method
                 patch_dequantization_on_save(transformer)
                 quantization_type = qfloat8
-                print_acc("Quantizing transformer")
+                self.print_and_status_update("Quantizing transformer")
                 quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                 freeze(transformer)
                 transformer.to(self.device_torch)
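
Side note on the hunk above: quantize and freeze here follow the optimum-quanto workflow, where quantize swaps a module's weights for qfloat8 stand-ins and freeze materializes the quantized tensors so the full-precision originals can be released. A minimal runnable sketch of that pattern; the small Linear stack below is a stand-in for the much larger Flux transformer, not the real model:

# Minimal sketch of the optimum-quanto flow used in this hunk.
import torch
from optimum.quanto import freeze, qfloat8, quantize

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)

quantize(model, weights=qfloat8)  # replace weights with qfloat8 qtensors
freeze(model)                     # materialize quantized weights, drop fp originals

with torch.no_grad():
    _ = model(torch.randn(1, 64))  # inference works as before
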
@@ -699,7 +700,7 @@ class StableDiffusion:
             flush()
 
             scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
-            print_acc("Loading vae")
+            self.print_and_status_update("Loading vae")
             vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
             flush()
 
@@ -708,7 +709,7 @@ class StableDiffusion:
                 text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
 
             else:
-                print_acc("Loading t5")
+                self.print_and_status_update("Loading t5")
                 tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
                 text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
                                                                 torch_dtype=dtype)
@@ -718,19 +719,19 @@ class StableDiffusion:
 
             if self.model_config.quantize_te:
                 if self.is_flex2:
-                    print_acc("Quantizing LLM")
+                    self.print_and_status_update("Quantizing LLM")
                 else:
-                    print_acc("Quantizing T5")
+                    self.print_and_status_update("Quantizing T5")
                 quantize(text_encoder_2, weights=qfloat8)
                 freeze(text_encoder_2)
                 flush()
 
-            print_acc("Loading clip")
+            self.print_and_status_update("Loading clip")
             text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
             tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
             text_encoder.to(self.device_torch, dtype=dtype)
 
-            print_acc("making pipe")
+            self.print_and_status_update("Making pipe")
             Pipe = FluxPipeline
             if self.is_flex2:
                 Pipe = Flex2Pipeline
@@ -747,7 +748,7 @@ class StableDiffusion:
             pipe.text_encoder_2 = text_encoder_2
             pipe.transformer = transformer
 
-            print_acc("preparing")
+            self.print_and_status_update("Preparing Model")
 
             text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
             tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
@@ -764,10 +765,10 @@ class StableDiffusion:
             pipe.transformer = pipe.transformer.to(self.device_torch)
             flush()
         elif self.model_config.is_lumina2:
-            print_acc("Loading Lumina2 model")
+            self.print_and_status_update("Loading Lumina2 model")
             # base_model_path = "black-forest-labs/FLUX.1-schnell"
             base_model_path = self.model_config.name_or_path_original
-            print_acc("Loading transformer")
+            self.print_and_status_update("Loading transformer")
             subfolder = 'transformer'
             transformer_path = model_path
             if os.path.exists(transformer_path):
@@ -803,7 +804,7 @@ class StableDiffusion:
                 # patch the state dict method
                 patch_dequantization_on_save(transformer)
                 quantization_type = qfloat8
-                print_acc("Quantizing transformer")
+                self.print_and_status_update("Quantizing transformer")
                 quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                 freeze(transformer)
                 transformer.to(self.device_torch)
@@ -813,16 +814,16 @@ class StableDiffusion:
             flush()
 
             scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
-            print_acc("Loading vae")
+            self.print_and_status_update("Loading vae")
             vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
             flush()
 
             if self.model_config.te_name_or_path is not None:
-                print_acc("Loading TE")
+                self.print_and_status_update("Loading TE")
                 tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
                 text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
             else:
-                print_acc("Loading Gemma2")
+                self.print_and_status_update("Loading Gemma2")
                 tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
                 text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
 
@@ -830,12 +831,12 @@ class StableDiffusion:
             flush()
 
             if self.model_config.quantize_te:
-                print_acc("Quantizing Gemma2")
+                self.print_and_status_update("Quantizing Gemma2")
                 quantize(text_encoder, weights=qfloat8)
                 freeze(text_encoder)
                 flush()
 
-            print_acc("making pipe")
+            self.print_and_status_update("Making pipe")
             pipe: Lumina2Text2ImgPipeline = Lumina2Text2ImgPipeline(
                 scheduler=scheduler,
                 text_encoder=None,
@@ -846,7 +847,7 @@ class StableDiffusion:
             pipe.text_encoder = text_encoder
             pipe.transformer = transformer
 
-            print_acc("preparing")
+            self.print_and_status_update("Preparing Model")
 
             text_encoder = pipe.text_encoder
             tokenizer = pipe.tokenizer
@@ -1041,6 +1042,17 @@ class StableDiffusion:
 
     def add_after_sample_image_hook(self, func):
         self._after_sample_img_hooks.append(func)
 
+    def _status_update(self, status: str):
+        for hook in self._status_update_hooks:
+            hook(status)
+
+    def print_and_status_update(self, status: str):
+        print_acc(status)
+        self._status_update(status)
+
+    def add_status_update_hook(self, func):
+        self._status_update_hooks.append(func)
+
     @torch.no_grad()
     def generate_images(
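
Taken together, the new methods form a small observer API: the UI registers a callback with add_status_update_hook, and each print_and_status_update call during load_model both prints the status string and fans it out to every registered hook. A minimal usage sketch, assuming sd is an already-constructed StableDiffusion instance; the queue plumbing below is hypothetical glue to a UI thread, not part of this commit:

import queue

status_queue = queue.Queue()  # hypothetical channel a UI thread could poll

# Every status string emitted during load_model() lands in the queue.
sd.add_status_update_hook(status_queue.put)

sd.load_model()  # fires the hook on each print_and_status_update() call

while not status_queue.empty():
    print("UI saw:", status_queue.get())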