mirror of
https://github.com/likelovewant/sd-forge-teacache.git
synced 2026-01-26 11:09:54 +00:00
Fix enabling and disabling
Remove unnecessary cache cleanup button
This commit is contained in:
@@ -22,14 +22,8 @@ class TeaCache(scripts.Script):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with InputAccordion(False, label=self.title(), elem_id="extensions-teacache") as teacache_enabled:
|
||||
with gr.Row():
|
||||
enable_teacache_checkbox = gr.Checkbox(
|
||||
label="Enable TeaCache",
|
||||
value=self.enable_teacache,
|
||||
tooltip="Enable TeaCache to speed up inference by caching intermediate results."
|
||||
)
|
||||
with gr.Row(visible=False) as teacache_settings: # Hide settings by default
|
||||
with InputAccordion(value=False, label=self.title()) as enable:
|
||||
with gr.Group(elem_classes="teacache"):
|
||||
rel_l1_thresh_slider = gr.Slider(
|
||||
label="Relative L1 Threshold",
|
||||
minimum=0.0,
|
||||
@@ -46,35 +40,20 @@ class TeaCache(scripts.Script):
|
||||
value=self.steps,
|
||||
tooltip="Number of steps to cache intermediate results."
|
||||
)
|
||||
with gr.Row():
|
||||
clear_cache_button = gr.Button("Clear Residual Cache", variant="secondary")
|
||||
with gr.Row():
|
||||
gr.Markdown("**Note**: Clear residual cache when changing image size or disabling TeaCache.")
|
||||
|
||||
# Define a function to toggle the visibility of TeaCache settings
|
||||
def toggle_teacache_settings(enable_teacache):
|
||||
return {teacache_settings: gr.update(visible=enable_teacache)}
|
||||
|
||||
# Bind the checkbox change event to toggle settings visibility
|
||||
enable_teacache_checkbox.change(
|
||||
fn=toggle_teacache_settings,
|
||||
inputs=enable_teacache_checkbox,
|
||||
outputs=teacache_settings
|
||||
)
|
||||
self.paste_field_names = []
|
||||
self.infotext_fields = [
|
||||
(enable, "TeaCache Enabled"),
|
||||
(rel_l1_thresh_slider, "TeaCache Relative L1 Threshold"),
|
||||
(steps_slider, "TeaCache Steps"),
|
||||
]
|
||||
|
||||
# Define the function to clear residual cache
|
||||
def clear_residual_cache_ui():
|
||||
self.clear_residual_cache()
|
||||
return "Residual cache cleared."
|
||||
for comp, name in self.infotext_fields:
|
||||
comp.do_not_save_to_config = True
|
||||
self.paste_field_names.append(name)
|
||||
|
||||
# Bind the button click event
|
||||
clear_cache_button.click(
|
||||
fn=clear_residual_cache_ui,
|
||||
outputs=None
|
||||
)
|
||||
|
||||
# Return UI components
|
||||
return [teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider]
|
||||
return [enable, rel_l1_thresh_slider, steps_slider]
|
||||
|
||||
def clear_residual_cache(self):
|
||||
"""Clear residual cache and free GPU memory."""
|
||||
@@ -92,7 +71,7 @@ class TeaCache(scripts.Script):
|
||||
torch.cuda.empty_cache()
|
||||
print("Residual cache cleared and GPU memory freed.")
|
||||
|
||||
def process(self, p, teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider):
|
||||
def process(self, p, enable, rel_l1_thresh_slider, steps_slider):
|
||||
# Get the current input dimensions
|
||||
current_input_shape = (p.width, p.height)
|
||||
if self.last_input_shape is not None and current_input_shape != self.last_input_shape:
|
||||
@@ -101,22 +80,21 @@ class TeaCache(scripts.Script):
|
||||
self.last_input_shape = current_input_shape
|
||||
|
||||
# Debug information only when TeaCache is enabled
|
||||
if teacache_enabled:
|
||||
print("TeaCache enabled:", teacache_enabled)
|
||||
print("Enable TeaCache checkbox:", enable_teacache_checkbox)
|
||||
if enable:
|
||||
print("TeaCache enabled:", enable)
|
||||
print("Relative L1 Threshold:", rel_l1_thresh_slider)
|
||||
print("Steps:", steps_slider)
|
||||
|
||||
# If TeaCache is enabled, add parameters to generation parameters
|
||||
if teacache_enabled:
|
||||
if enable:
|
||||
p.extra_generation_params.update({
|
||||
"enable_teacache": enable_teacache_checkbox,
|
||||
"enable_teacache": enable,
|
||||
"rel_l1_thresh": rel_l1_thresh_slider,
|
||||
"steps": steps_slider,
|
||||
})
|
||||
|
||||
# Dynamically modify class attributes of IntegratedFluxTransformer2DModel
|
||||
setattr(IntegratedFluxTransformer2DModel, "enable_teacache", enable_teacache_checkbox)
|
||||
setattr(IntegratedFluxTransformer2DModel, "enable_teacache", enable)
|
||||
setattr(IntegratedFluxTransformer2DModel, "cnt", 0)
|
||||
setattr(IntegratedFluxTransformer2DModel, "rel_l1_thresh", rel_l1_thresh_slider)
|
||||
setattr(IntegratedFluxTransformer2DModel, "steps", steps_slider)
|
||||
@@ -125,48 +103,47 @@ class TeaCache(scripts.Script):
|
||||
setattr(IntegratedFluxTransformer2DModel, "previous_residual", None)
|
||||
|
||||
# Replace the original inner_forward method
|
||||
if hasattr(IntegratedFluxTransformer2DModel, "inner_forward"):
|
||||
# Save the original inner_forward method
|
||||
if not hasattr(IntegratedFluxTransformer2DModel, "original_inner_forward"):
|
||||
setattr(IntegratedFluxTransformer2DModel, "original_inner_forward", IntegratedFluxTransformer2DModel.inner_forward)
|
||||
# Replace with the new inner_forward method
|
||||
IntegratedFluxTransformer2DModel.inner_forward = patched_inner_forward
|
||||
IntegratedFluxTransformer2DModel.inner_forward = patched_inner_forward
|
||||
|
||||
# Replace the original forward method
|
||||
if not hasattr(IntegratedFluxTransformer2DModel, "original_forward"):
|
||||
IntegratedFluxTransformer2DModel.original_forward = IntegratedFluxTransformer2DModel.forward
|
||||
IntegratedFluxTransformer2DModel.forward = patched_forward
|
||||
IntegratedFluxTransformer2DModel.forward = patched_forward
|
||||
setattr(IntegratedFluxTransformer2DModel, "patched_forward", patched_forward) # Mark that we have replaced the forward method
|
||||
else:
|
||||
# If TeaCache is disabled, restore the original inner_forward method
|
||||
if hasattr(IntegratedFluxTransformer2DModel, "original_inner_forward"):
|
||||
IntegratedFluxTransformer2DModel.inner_forward = IntegratedFluxTransformer2DModel.original_inner_forward
|
||||
# Clear residual cache and reset TeaCache attributes
|
||||
self.clear_residual_cache()
|
||||
setattr(IntegratedFluxTransformer2DModel, "enable_teacache", False)
|
||||
setattr(IntegratedFluxTransformer2DModel, "cnt", 0)
|
||||
setattr(IntegratedFluxTransformer2DModel, "rel_l1_thresh", 0.4)
|
||||
setattr(IntegratedFluxTransformer2DModel, "steps", 25)
|
||||
setattr(IntegratedFluxTransformer2DModel, "accumulated_rel_l1_distance", 0)
|
||||
setattr(IntegratedFluxTransformer2DModel, "previous_modulated_input", None)
|
||||
setattr(IntegratedFluxTransformer2DModel, "previous_residual", None)
|
||||
print("TeaCache fully disabled and cache cleared.")
|
||||
|
||||
# Restore the original forward method
|
||||
if hasattr(IntegratedFluxTransformer2DModel, "original_forward"):
|
||||
# Restore the original forward method only if it was replaced and remove patched_forward flag
|
||||
if hasattr(IntegratedFluxTransformer2DModel, "original_forward") and hasattr(IntegratedFluxTransformer2DModel, "patched_forward"):
|
||||
IntegratedFluxTransformer2DModel.forward = IntegratedFluxTransformer2DModel.original_forward
|
||||
# Remove the patched forward method to avoid recursion
|
||||
if hasattr(IntegratedFluxTransformer2DModel, "patched_forward"):
|
||||
delattr(IntegratedFluxTransformer2DModel, "patched_forward")
|
||||
delattr(IntegratedFluxTransformer2DModel, "patched_forward")
|
||||
|
||||
# Clear residual cache and reset TeaCache attributes
|
||||
self.clear_residual_cache()
|
||||
setattr(IntegratedFluxTransformer2DModel, "enable_teacache", False)
|
||||
setattr(IntegratedFluxTransformer2DModel, "cnt", 0)
|
||||
setattr(IntegratedFluxTransformer2DModel, "rel_l1_thresh", 0.4)
|
||||
setattr(IntegratedFluxTransformer2DModel, "steps", 25)
|
||||
setattr(IntegratedFluxTransformer2DModel, "accumulated_rel_l1_distance", 0)
|
||||
setattr(IntegratedFluxTransformer2DModel, "previous_modulated_input", None)
|
||||
setattr(IntegratedFluxTransformer2DModel, "previous_residual", None)
|
||||
print("TeaCache fully disabled and cache cleared.")
|
||||
|
||||
|
||||
|
||||
|
||||
def patched_inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
|
||||
# Print "TeaCache is enabled!" only once per generation
|
||||
if self.enable_teacache:
|
||||
if not hasattr(self, "_teacache_enabled_printed"): # Check if the message has been printed
|
||||
print("TeaCache is enabled!")
|
||||
self._teacache_enabled_printed = True # Set flag to avoid repeated printing
|
||||
if getattr(self, 'enable_teacache', False) and not hasattr(self, "_teacache_enabled_printed"): # Check if the message has been printed
|
||||
print("TeaCache is enabled!")
|
||||
self._teacache_enabled_printed = True # Set flag to avoid repeated printing
|
||||
|
||||
# If TeaCache is not enabled, call the original method
|
||||
if not hasattr(self, "enable_teacache") or not self.enable_teacache:
|
||||
if not getattr(self, 'enable_teacache', False):
|
||||
return self.original_inner_forward(img, img_ids, txt, txt_ids, timesteps, y, guidance)
|
||||
|
||||
# Get parameters from UI
|
||||
|
||||
Reference in New Issue
Block a user