fix enable l1 even diable teacache

This commit is contained in:
likelovewant
2025-01-13 14:15:17 +08:00
parent 4b217e013f
commit 14ca779359

View File

@@ -29,7 +29,7 @@ class TeaCache(scripts.Script):
value=self.enable_teacache,
tooltip="Enable TeaCache to speed up inference by caching intermediate results."
)
with gr.Row():
with gr.Row(visible=False) as teacache_settings: # Hide settings by default
rel_l1_thresh_slider = gr.Slider(
label="Relative L1 Threshold",
minimum=0.0,
@@ -38,7 +38,6 @@ class TeaCache(scripts.Script):
value=self.rel_l1_thresh,
tooltip="Threshold for caching intermediate results. Lower values cache more aggressively."
)
with gr.Row():
steps_slider = gr.Slider(
label="Steps",
minimum=1,
@@ -52,6 +51,17 @@ class TeaCache(scripts.Script):
with gr.Row():
gr.Markdown("**Note**: Clear residual cache when changing image size or disabling TeaCache.")
# Define a function to toggle the visibility of TeaCache settings
def toggle_teacache_settings(enable_teacache):
return {teacache_settings: gr.update(visible=enable_teacache)}
# Bind the checkbox change event to toggle settings visibility
enable_teacache_checkbox.change(
fn=toggle_teacache_settings,
inputs=enable_teacache_checkbox,
outputs=teacache_settings
)
# Define the function to clear residual cache
def clear_residual_cache_ui():
self.clear_residual_cache()
@@ -88,11 +98,12 @@ class TeaCache(scripts.Script):
self.clear_residual_cache()
self.last_input_shape = current_input_shape
# Debug information
print("TeaCache enabled:", teacache_enabled)
print("Enable TeaCache checkbox:", enable_teacache_checkbox)
print("Relative L1 Threshold:", rel_l1_thresh_slider)
print("Steps:", steps_slider)
# Debug information only when TeaCache is enabled
if teacache_enabled:
print("TeaCache enabled:", teacache_enabled)
print("Enable TeaCache checkbox:", enable_teacache_checkbox)
print("Relative L1 Threshold:", rel_l1_thresh_slider)
print("Steps:", steps_slider)
# If TeaCache is enabled, add parameters to generation parameters
if teacache_enabled:
@@ -233,12 +244,10 @@ def patched_inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidan
def patched_forward(self, x, timestep, context, y, guidance=None, **kwargs):
# Set TeaCache parameters if provided
if hasattr(self, "enable_teacache"):
# Set TeaCache parameters if provided and enabled
if hasattr(self, "enable_teacache") and kwargs.get("enable_teacache", False):
self.enable_teacache = kwargs.get("enable_teacache", self.enable_teacache)
if hasattr(self, "rel_l1_thresh"):
self.rel_l1_thresh = kwargs.get("rel_l1_thresh", self.rel_l1_thresh)
if hasattr(self, "steps"):
self.steps = kwargs.get("steps", self.steps)
# Call the original forward method