From 47f7dc6e30b7ffe2ac2a44b3f591bb3448da9fae Mon Sep 17 00:00:00 2001 From: Firetheft <97147055+Firetheft@users.noreply.github.com> Date: Tue, 14 Jan 2025 21:56:52 +0800 Subject: [PATCH] Fix enabling and disabling Remove unnecessary cache cleanup button --- scripts/teacache.py | 107 +++++++++++++++++--------------------------- 1 file changed, 42 insertions(+), 65 deletions(-) diff --git a/scripts/teacache.py b/scripts/teacache.py index f557b41..7f385af 100644 --- a/scripts/teacache.py +++ b/scripts/teacache.py @@ -22,14 +22,8 @@ class TeaCache(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - with InputAccordion(False, label=self.title(), elem_id="extensions-teacache") as teacache_enabled: - with gr.Row(): - enable_teacache_checkbox = gr.Checkbox( - label="Enable TeaCache", - value=self.enable_teacache, - tooltip="Enable TeaCache to speed up inference by caching intermediate results." - ) - with gr.Row(visible=False) as teacache_settings: # Hide settings by default + with InputAccordion(value=False, label=self.title()) as enable: + with gr.Group(elem_classes="teacache"): rel_l1_thresh_slider = gr.Slider( label="Relative L1 Threshold", minimum=0.0, @@ -46,35 +40,20 @@ class TeaCache(scripts.Script): value=self.steps, tooltip="Number of steps to cache intermediate results." ) - with gr.Row(): - clear_cache_button = gr.Button("Clear Residual Cache", variant="secondary") - with gr.Row(): - gr.Markdown("**Note**: Clear residual cache when changing image size or disabling TeaCache.") - # Define a function to toggle the visibility of TeaCache settings - def toggle_teacache_settings(enable_teacache): - return {teacache_settings: gr.update(visible=enable_teacache)} - # Bind the checkbox change event to toggle settings visibility - enable_teacache_checkbox.change( - fn=toggle_teacache_settings, - inputs=enable_teacache_checkbox, - outputs=teacache_settings - ) + self.paste_field_names = [] + self.infotext_fields = [ + (enable, "TeaCache Enabled"), + (rel_l1_thresh_slider, "TeaCache Relative L1 Threshold"), + (steps_slider, "TeaCache Steps"), + ] - # Define the function to clear residual cache - def clear_residual_cache_ui(): - self.clear_residual_cache() - return "Residual cache cleared." + for comp, name in self.infotext_fields: + comp.do_not_save_to_config = True + self.paste_field_names.append(name) - # Bind the button click event - clear_cache_button.click( - fn=clear_residual_cache_ui, - outputs=None - ) - - # Return UI components - return [teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider] + return [enable, rel_l1_thresh_slider, steps_slider] def clear_residual_cache(self): """Clear residual cache and free GPU memory.""" @@ -92,7 +71,7 @@ class TeaCache(scripts.Script): torch.cuda.empty_cache() print("Residual cache cleared and GPU memory freed.") - def process(self, p, teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider): + def process(self, p, enable, rel_l1_thresh_slider, steps_slider): # Get the current input dimensions current_input_shape = (p.width, p.height) if self.last_input_shape is not None and current_input_shape != self.last_input_shape: @@ -101,22 +80,21 @@ class TeaCache(scripts.Script): self.last_input_shape = current_input_shape # Debug information only when TeaCache is enabled - if teacache_enabled: - print("TeaCache enabled:", teacache_enabled) - print("Enable TeaCache checkbox:", enable_teacache_checkbox) + if enable: + print("TeaCache enabled:", enable) print("Relative L1 Threshold:", rel_l1_thresh_slider) print("Steps:", steps_slider) # If TeaCache is enabled, add parameters to generation parameters - if teacache_enabled: + if enable: p.extra_generation_params.update({ - "enable_teacache": enable_teacache_checkbox, + "enable_teacache": enable, "rel_l1_thresh": rel_l1_thresh_slider, "steps": steps_slider, }) # Dynamically modify class attributes of IntegratedFluxTransformer2DModel - setattr(IntegratedFluxTransformer2DModel, "enable_teacache", enable_teacache_checkbox) + setattr(IntegratedFluxTransformer2DModel, "enable_teacache", enable) setattr(IntegratedFluxTransformer2DModel, "cnt", 0) setattr(IntegratedFluxTransformer2DModel, "rel_l1_thresh", rel_l1_thresh_slider) setattr(IntegratedFluxTransformer2DModel, "steps", steps_slider) @@ -125,48 +103,47 @@ class TeaCache(scripts.Script): setattr(IntegratedFluxTransformer2DModel, "previous_residual", None) # Replace the original inner_forward method - if hasattr(IntegratedFluxTransformer2DModel, "inner_forward"): - # Save the original inner_forward method + if not hasattr(IntegratedFluxTransformer2DModel, "original_inner_forward"): setattr(IntegratedFluxTransformer2DModel, "original_inner_forward", IntegratedFluxTransformer2DModel.inner_forward) - # Replace with the new inner_forward method - IntegratedFluxTransformer2DModel.inner_forward = patched_inner_forward + IntegratedFluxTransformer2DModel.inner_forward = patched_inner_forward # Replace the original forward method if not hasattr(IntegratedFluxTransformer2DModel, "original_forward"): IntegratedFluxTransformer2DModel.original_forward = IntegratedFluxTransformer2DModel.forward - IntegratedFluxTransformer2DModel.forward = patched_forward + IntegratedFluxTransformer2DModel.forward = patched_forward + setattr(IntegratedFluxTransformer2DModel, "patched_forward", patched_forward) # Mark that we have replaced the forward method else: # If TeaCache is disabled, restore the original inner_forward method if hasattr(IntegratedFluxTransformer2DModel, "original_inner_forward"): IntegratedFluxTransformer2DModel.inner_forward = IntegratedFluxTransformer2DModel.original_inner_forward - # Clear residual cache and reset TeaCache attributes - self.clear_residual_cache() - setattr(IntegratedFluxTransformer2DModel, "enable_teacache", False) - setattr(IntegratedFluxTransformer2DModel, "cnt", 0) - setattr(IntegratedFluxTransformer2DModel, "rel_l1_thresh", 0.4) - setattr(IntegratedFluxTransformer2DModel, "steps", 25) - setattr(IntegratedFluxTransformer2DModel, "accumulated_rel_l1_distance", 0) - setattr(IntegratedFluxTransformer2DModel, "previous_modulated_input", None) - setattr(IntegratedFluxTransformer2DModel, "previous_residual", None) - print("TeaCache fully disabled and cache cleared.") - # Restore the original forward method - if hasattr(IntegratedFluxTransformer2DModel, "original_forward"): + # Restore the original forward method only if it was replaced and remove patched_forward flag + if hasattr(IntegratedFluxTransformer2DModel, "original_forward") and hasattr(IntegratedFluxTransformer2DModel, "patched_forward"): IntegratedFluxTransformer2DModel.forward = IntegratedFluxTransformer2DModel.original_forward - # Remove the patched forward method to avoid recursion - if hasattr(IntegratedFluxTransformer2DModel, "patched_forward"): - delattr(IntegratedFluxTransformer2DModel, "patched_forward") + delattr(IntegratedFluxTransformer2DModel, "patched_forward") + + # Clear residual cache and reset TeaCache attributes + self.clear_residual_cache() + setattr(IntegratedFluxTransformer2DModel, "enable_teacache", False) + setattr(IntegratedFluxTransformer2DModel, "cnt", 0) + setattr(IntegratedFluxTransformer2DModel, "rel_l1_thresh", 0.4) + setattr(IntegratedFluxTransformer2DModel, "steps", 25) + setattr(IntegratedFluxTransformer2DModel, "accumulated_rel_l1_distance", 0) + setattr(IntegratedFluxTransformer2DModel, "previous_modulated_input", None) + setattr(IntegratedFluxTransformer2DModel, "previous_residual", None) + print("TeaCache fully disabled and cache cleared.") + + def patched_inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): # Print "TeaCache is enabled!" only once per generation - if self.enable_teacache: - if not hasattr(self, "_teacache_enabled_printed"): # Check if the message has been printed - print("TeaCache is enabled!") - self._teacache_enabled_printed = True # Set flag to avoid repeated printing + if getattr(self, 'enable_teacache', False) and not hasattr(self, "_teacache_enabled_printed"): # Check if the message has been printed + print("TeaCache is enabled!") + self._teacache_enabled_printed = True # Set flag to avoid repeated printing # If TeaCache is not enabled, call the original method - if not hasattr(self, "enable_teacache") or not self.enable_teacache: + if not getattr(self, 'enable_teacache', False): return self.original_inner_forward(img, img_ids, txt, txt_ids, timesteps, y, guidance) # Get parameters from UI