mirror of
https://github.com/likelovewant/sd-forge-teacache.git
synced 2026-01-26 11:09:54 +00:00
259 lines
12 KiB
Python
259 lines
12 KiB
Python
import torch
|
|
import numpy as np
|
|
from torch import Tensor
|
|
import gradio as gr
|
|
from modules import scripts
|
|
from modules.ui_components import InputAccordion
|
|
from backend.nn.flux import IntegratedFluxTransformer2DModel, timestep_embedding
|
|
|
|
|
|
class TeaCache(scripts.Script):
    """Always-visible script that switches TeaCache on and off for Flux.

    The script exposes the TeaCache settings in the UI and, at the start of
    each generation, installs or removes the module-level
    ``patched_inner_forward`` / ``patched_forward`` functions on
    ``IntegratedFluxTransformer2DModel``. All settings and cache state are
    stored as *class* attributes of that model class.
    """

    def __init__(self):
        super().__init__()
        self.enable_teacache = False
        self.rel_l1_thresh = 0.4
        self.steps = 25
        self.last_input_shape = None  # (width, height) of the previous run

    def title(self):
        """Return the label used for the accordion."""
        return "TeaCache"

    def show(self, is_img2img):
        """Show the script in every tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the TeaCache controls.

        Returns:
            The components ``[accordion toggle, enable checkbox, threshold
            slider, steps slider]``, in the same order ``process`` receives
            their values.
        """
        with InputAccordion(False, label=self.title(), elem_id="extensions-teacache") as teacache_enabled:
            with gr.Row():
                enable_teacache_checkbox = gr.Checkbox(
                    label="Enable TeaCache",
                    value=self.enable_teacache,
                    tooltip="Enable TeaCache to speed up inference by caching intermediate results."
                )
            with gr.Row(visible=False) as teacache_settings:  # Hidden until the checkbox is ticked
                rel_l1_thresh_slider = gr.Slider(
                    label="Relative L1 Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=self.rel_l1_thresh,
                    # The patched forward pass reuses the cached residual while
                    # the accumulated distance stays BELOW this threshold, so a
                    # higher value means more reuse (the old text said "lower").
                    tooltip="Threshold for caching intermediate results. Higher values cache more aggressively."
                )
                steps_slider = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=self.steps,
                    tooltip="Number of steps to cache intermediate results."
                )
            with gr.Row():
                clear_cache_button = gr.Button("Clear Residual Cache", variant="secondary")
            with gr.Row():
                gr.Markdown("**Note**: Clear residual cache when changing image size or disabling TeaCache.")

        # Show the settings row only while the checkbox is ticked.
        def toggle_teacache_settings(enable_teacache):
            return {teacache_settings: gr.update(visible=enable_teacache)}

        enable_teacache_checkbox.change(
            fn=toggle_teacache_settings,
            inputs=enable_teacache_checkbox,
            outputs=teacache_settings
        )

        # Button handler: drop the cached tensors on demand.
        def clear_residual_cache_ui():
            self.clear_residual_cache()
            return "Residual cache cleared."

        clear_cache_button.click(
            fn=clear_residual_cache_ui,
            outputs=None
        )

        return [teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider]

    def clear_residual_cache(self):
        """Clear residual cache and free GPU memory.

        Only class attributes that already exist are reset.
        NOTE(review): a same-named attribute stored on a model *instance*
        would shadow these class attributes and is not touched here.
        """
        model_cls = IntegratedFluxTransformer2DModel
        resets = (
            ("previous_residual", None),
            ("previous_modulated_input", None),
            ("cnt", 0),
            ("accumulated_rel_l1_distance", 0),
        )
        for attr_name, cleared_value in resets:
            if hasattr(model_cls, attr_name):
                setattr(model_cls, attr_name, cleared_value)
        if hasattr(model_cls, "_teacache_enabled_printed"):
            delattr(model_cls, "_teacache_enabled_printed")
        # Hand the released blocks back to the CUDA allocator.
        torch.cuda.empty_cache()
        print("Residual cache cleared and GPU memory freed.")

    @staticmethod
    def _configure_model(enabled, rel_l1_thresh, steps):
        """Write the TeaCache settings and a fresh cache state onto the model class."""
        model_cls = IntegratedFluxTransformer2DModel
        setattr(model_cls, "enable_teacache", enabled)
        setattr(model_cls, "cnt", 0)
        setattr(model_cls, "rel_l1_thresh", rel_l1_thresh)
        setattr(model_cls, "steps", steps)
        setattr(model_cls, "accumulated_rel_l1_distance", 0)
        setattr(model_cls, "previous_modulated_input", None)
        setattr(model_cls, "previous_residual", None)

    def process(self, p, teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider):
        """Install or remove the TeaCache patches before a generation starts.

        Args:
            p: Processing object; ``width``, ``height`` and
                ``extra_generation_params`` are used.
            teacache_enabled: Value of the accordion toggle.
            enable_teacache_checkbox: Value of the "Enable TeaCache" checkbox.
            rel_l1_thresh_slider: Value of the threshold slider.
            steps_slider: Value of the steps slider.
        """
        model_cls = IntegratedFluxTransformer2DModel

        # A cached residual is only valid for the resolution it was computed at.
        current_input_shape = (p.width, p.height)
        if self.last_input_shape is not None and current_input_shape != self.last_input_shape:
            self.clear_residual_cache()
        self.last_input_shape = current_input_shape

        if teacache_enabled:
            # Debug information only when TeaCache is enabled
            print("TeaCache enabled:", teacache_enabled)
            print("Enable TeaCache checkbox:", enable_teacache_checkbox)
            print("Relative L1 Threshold:", rel_l1_thresh_slider)
            print("Steps:", steps_slider)

            p.extra_generation_params.update({
                "enable_teacache": enable_teacache_checkbox,
                "rel_l1_thresh": rel_l1_thresh_slider,
                "steps": steps_slider,
            })

            self._configure_model(enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider)

            # Patch inner_forward, but save the original only while the patch
            # is not installed yet. Saving it on every run would overwrite the
            # real method with patched_inner_forward itself, after which the
            # fallback path recurses forever and "restore" restores the patch.
            if hasattr(model_cls, "inner_forward") and model_cls.inner_forward is not patched_inner_forward:
                setattr(model_cls, "original_inner_forward", model_cls.inner_forward)
                model_cls.inner_forward = patched_inner_forward

            # Same rule for forward. Checking the installed function (rather
            # than the presence of original_forward) lets the patch be
            # re-applied after it was removed by a disabled run.
            if model_cls.forward is not patched_forward:
                model_cls.original_forward = model_cls.forward
                model_cls.forward = patched_forward
        else:
            # If TeaCache is disabled, restore the original inner_forward method
            if hasattr(model_cls, "original_inner_forward"):
                model_cls.inner_forward = model_cls.original_inner_forward
                # Clear residual cache and reset TeaCache attributes to defaults
                self.clear_residual_cache()
                self._configure_model(False, 0.4, 25)
                print("TeaCache fully disabled and cache cleared.")

            # Restore the original forward method
            if hasattr(model_cls, "original_forward"):
                model_cls.forward = model_cls.original_forward
                # Remove a patched_forward attribute if one was ever attached
                if hasattr(model_cls, "patched_forward"):
                    delattr(model_cls, "patched_forward")
|
|
|
|
|
|
def patched_inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
    """Replacement for ``inner_forward`` that can reuse a cached residual.

    On each call the embedded image input is compared with the one from the
    previous call. While the accumulated (polynomially rescaled) relative L1
    change stays below ``rel_l1_thresh``, the double/single transformer blocks
    are skipped and the residual stored by the last full pass is added instead.
    The first and last call of every ``steps``-long cycle always run the full
    pass.

    Cache state (``cnt``, ``accumulated_rel_l1_distance``,
    ``previous_modulated_input``, ``previous_residual``) is read from and
    written to the ``IntegratedFluxTransformer2DModel`` class, not the
    instance. Writing it through ``self`` would create instance attributes
    that shadow the class attributes the script resets, so those resets would
    silently stop working after the first call.

    Raises:
        ValueError: If ``img`` or ``txt`` is not 3-dimensional, or if the model
            uses guidance embedding and ``guidance`` is ``None``.
    """
    # If TeaCache is not enabled, call the original method.
    if not getattr(self, "enable_teacache", False):
        return self.original_inner_forward(img, img_ids, txt, txt_ids, timesteps, y, guidance)

    state = IntegratedFluxTransformer2DModel

    # Announce once; the flag lives on the class so that clearing the cache
    # (which deletes the class-level flag) re-arms the message.
    if not hasattr(self, "_teacache_enabled_printed"):
        print("TeaCache is enabled!")
        state._teacache_enabled_printed = True

    # Settings coming from the UI (class level) or per-instance overrides.
    rel_l1_thresh = getattr(self, "rel_l1_thresh", 0.4)
    steps = getattr(self, "steps", 25)

    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # Image and text embedding
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))

    # If guidance_embed is enabled, add guidance information
    if self.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)

    # Merge image and text IDs
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)

    # Snapshot of the embedded input used for the change estimate. It is a
    # copy so that later work on ``img`` cannot alter the stored reference.
    modulated_inp = img.clone()

    previous_input = getattr(state, "previous_modulated_input", None)
    previous_residual = getattr(state, "previous_residual", None)
    cnt = getattr(state, "cnt", 0)
    accumulated = getattr(state, "accumulated_rel_l1_distance", 0)

    # A cache computed for a different tensor shape cannot be reused.
    if previous_input is not None and previous_input.shape != modulated_inp.shape:
        previous_input = None
        previous_residual = None
        state.previous_residual = None
        print(" Cleared cache due to shape mismatch.")

    if cnt == 0 or cnt == steps - 1:
        # First and last call of a cycle always run the full pass.
        should_calc = True
        accumulated = 0
    elif previous_input is not None:
        coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
        rescale_func = np.poly1d(coefficients)
        rel_change = (
            (modulated_inp - previous_input).abs().mean() / previous_input.abs().mean()
        ).item()
        accumulated += rescale_func(rel_change)
        if accumulated < rel_l1_thresh:
            should_calc = False
        else:
            should_calc = True
            accumulated = 0
    else:
        should_calc = True

    # Without a stored residual there is nothing to reuse; skipping here would
    # bypass every transformer block, so fall back to the full pass.
    if not should_calc and previous_residual is None:
        should_calc = True
        accumulated = 0

    cnt += 1
    if cnt >= steps:
        cnt = 0
    state.previous_modulated_input = modulated_inp
    state.accumulated_rel_l1_distance = accumulated
    state.cnt = cnt

    if should_calc:
        ori_img = img.clone()
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the text tokens that were concatenated in front of the image tokens.
        img = img[:, txt.shape[1]:, ...]
        state.previous_residual = img - ori_img
    else:
        img = img + previous_residual

    # Final output
    img = self.final_layer(img, vec)
    return img
|
|
|
|
|
|
def patched_forward(self, x, timestep, context, y, guidance=None, **kwargs):
    """Apply per-call TeaCache overrides, then delegate to the original forward.

    Overrides are applied only when the object already has an
    ``enable_teacache`` attribute and the call passes a truthy
    ``enable_teacache`` keyword. In that case each of ``enable_teacache``,
    ``rel_l1_thresh`` and ``steps`` that is present in ``kwargs`` is stored on
    ``self``. Keywords that are absent are left alone: nothing is read from or
    copied onto the instance, so a missing attribute cannot raise and
    class-level settings are not shadowed by stale instance copies.

    Returns:
        Whatever ``self.original_forward`` returns; all keyword arguments are
        forwarded unchanged.
    """
    if hasattr(self, "enable_teacache") and kwargs.get("enable_teacache", False):
        for option in ("enable_teacache", "rel_l1_thresh", "steps"):
            if option in kwargs:
                setattr(self, option, kwargs[option])

    # Call the original forward method
    return self.original_forward(x, timestep, context, y, guidance, **kwargs)