mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Capture speed from the timer for the ui
This commit is contained in:
@@ -700,7 +700,7 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
self.print_and_status_update("Loading vae")
|
||||
self.print_and_status_update("Loading VAE")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
@@ -709,7 +709,7 @@ class StableDiffusion:
|
||||
text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
|
||||
else:
|
||||
self.print_and_status_update("Loading t5")
|
||||
self.print_and_status_update("Loading T5")
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
|
||||
torch_dtype=dtype)
|
||||
@@ -726,7 +726,7 @@ class StableDiffusion:
|
||||
freeze(text_encoder_2)
|
||||
flush()
|
||||
|
||||
self.print_and_status_update("Loading clip")
|
||||
self.print_and_status_update("Loading CLIP")
|
||||
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
@@ -9,6 +9,7 @@ class Timer:
|
||||
self.timers = OrderedDict()
|
||||
self.active_timers = {}
|
||||
self.current_timer = None # Used for the context manager functionality
|
||||
self._after_print_hooks = []
|
||||
|
||||
def start(self, timer_name):
|
||||
if timer_name not in self.timers:
|
||||
@@ -34,12 +35,20 @@ class Timer:
|
||||
if len(self.timers[timer_name]) > self.max_buffer:
|
||||
self.timers[timer_name].popleft()
|
||||
|
||||
def add_after_print_hook(self, hook):
|
||||
self._after_print_hooks.append(hook)
|
||||
|
||||
def print(self):
|
||||
print(f"\nTimer '{self.name}':")
|
||||
timing_dict = {}
|
||||
# sort by longest at top
|
||||
for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
|
||||
avg_time = sum(timings) / len(timings)
|
||||
print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
|
||||
timing_dict[timer_name] = avg_time
|
||||
|
||||
for hook in self._after_print_hooks:
|
||||
hook(timing_dict)
|
||||
|
||||
print('')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user