diff --git a/scripts/teacache.py b/scripts/teacache.py index dd709c0..9274194 100644 --- a/scripts/teacache.py +++ b/scripts/teacache.py @@ -29,7 +29,7 @@ class TeaCache(scripts.Script): value=self.enable_teacache, tooltip="Enable TeaCache to speed up inference by caching intermediate results." ) - with gr.Row(): + with gr.Row(visible=False) as teacache_settings: # Hide settings by default rel_l1_thresh_slider = gr.Slider( label="Relative L1 Threshold", minimum=0.0, @@ -38,7 +38,6 @@ class TeaCache(scripts.Script): value=self.rel_l1_thresh, tooltip="Threshold for caching intermediate results. Lower values cache more aggressively." ) - with gr.Row(): steps_slider = gr.Slider( label="Steps", minimum=1, @@ -52,6 +51,17 @@ class TeaCache(scripts.Script): with gr.Row(): gr.Markdown("**Note**: Clear residual cache when changing image size or disabling TeaCache.") + # Define a function to toggle the visibility of TeaCache settings + def toggle_teacache_settings(enable_teacache): + return {teacache_settings: gr.update(visible=enable_teacache)} + + # Bind the checkbox change event to toggle settings visibility + enable_teacache_checkbox.change( + fn=toggle_teacache_settings, + inputs=enable_teacache_checkbox, + outputs=teacache_settings + ) + # Define the function to clear residual cache def clear_residual_cache_ui(): self.clear_residual_cache() @@ -88,11 +98,12 @@ class TeaCache(scripts.Script): self.clear_residual_cache() self.last_input_shape = current_input_shape - # Debug information - print("TeaCache enabled:", teacache_enabled) - print("Enable TeaCache checkbox:", enable_teacache_checkbox) - print("Relative L1 Threshold:", rel_l1_thresh_slider) - print("Steps:", steps_slider) + # Debug information only when TeaCache is enabled + if teacache_enabled: + print("TeaCache enabled:", teacache_enabled) + print("Enable TeaCache checkbox:", enable_teacache_checkbox) + print("Relative L1 Threshold:", rel_l1_thresh_slider) + print("Steps:", steps_slider) # If TeaCache is enabled, add parameters to generation parameters if teacache_enabled: @@ -233,12 +244,10 @@ def patched_inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidan def patched_forward(self, x, timestep, context, y, guidance=None, **kwargs): - # Set TeaCache parameters if provided - if hasattr(self, "enable_teacache"): + # Set TeaCache parameters if provided and enabled + if hasattr(self, "enable_teacache") and kwargs.get("enable_teacache", False): self.enable_teacache = kwargs.get("enable_teacache", self.enable_teacache) - if hasattr(self, "rel_l1_thresh"): self.rel_l1_thresh = kwargs.get("rel_l1_thresh", self.rel_l1_thresh) - if hasattr(self, "steps"): self.steps = kwargs.get("steps", self.steps) # Call the original forward method