Initial commit of teacache for forge webui

This commit is contained in:
likelovewant
2025-01-12 20:59:17 +08:00
commit 119c2ae21a
2 changed files with 303 additions and 0 deletions

README.md Normal file

@@ -0,0 +1,58 @@
## Sd-Forge-TeaCache: Speed up Your Diffusion Models

**Introduction**

Timestep Embedding Aware Cache (TeaCache) is a training-free caching approach that exploits the fluctuating differences between model outputs across timesteps. It can significantly boost inference speed for a range of diffusion models, including image, video, and audio models.
This extension integrates TeaCache into SD Forge WebUI for Flux models only. Installation is as straightforward as any other extension:

* **Clone:** run `git clone https://github.com/likelovewant/sd-forge-teacache.git` inside your `extensions` directory, then relaunch the WebUI.
**Speed Up Your Diffusion Generation**

TeaCache can accelerate FLUX inference by up to 2x with minimal visual quality degradation, all without requiring any training.

Within the Forge WebUI, you can easily adjust the following settings:

* **Relative L1 Threshold:** Controls the sensitivity of TeaCache's caching mechanism; higher values cache more aggressively and yield larger speedups.
* **Steps:** Should match the number of sampling steps used for generation.
**Performance Tuning**

Based on [TeaCache4FLUX](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4FLUX), different thresholds yield different speedups (see the sketch after this list):

* 0.25 threshold for 1.5x speedup
* 0.4 threshold for 1.8x speedup
* 0.6 threshold for 2.0x speedup
* 0.8 threshold for 2.25x speedup
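
To illustrate how the threshold drives caching, here is a minimal sketch of the decision rule used in `scripts/teacache.py`. The helper `should_recompute` is illustrative only (not part of the extension's API); the polynomial coefficients come from TeaCache4FLUX:

```python
import numpy as np
import torch

# Rescaling polynomial fitted for FLUX (from TeaCache4FLUX).
rescale = np.poly1d([4.98651651e+02, -2.83781631e+02, 5.58554382e+01,
                     -3.82021401e+00, 2.64230861e-01])

def should_recompute(modulated_inp: torch.Tensor, prev_inp: torch.Tensor,
                     accumulated: float, threshold: float):
    """Accumulate the rescaled relative L1 distance between consecutive
    timestep inputs; recompute only once it crosses the threshold."""
    rel_l1 = ((modulated_inp - prev_inp).abs().mean()
              / prev_inp.abs().mean()).item()
    accumulated += rescale(rel_l1)
    if accumulated < threshold:
        return False, accumulated  # reuse the cached residual
    return True, 0.0               # recompute and reset the accumulator
```

A higher threshold lets more steps reuse the cached residual before a recompute is forced, which is why 0.8 gives a larger speedup than 0.25.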
**Important Notes:**

* **Maintain Consistency:** Keep the TeaCache steps aligned with the sampling steps used for Flux generation; discrepancies can lead to lower-quality outputs.
* **LoRA Considerations:** When using LoRAs, adjust the steps or scales based on your GPU's capabilities. A recommended starting point is 28 steps or more.
To ensure smooth operation, remember to:

1. **Clear Residual Cache (optional):** When changing image sizes or disabling the TeaCache extension, click "Clear Residual Cache" within the Forge WebUI. This prevents potential conflicts and maintains optimal performance (a simplified version of what the button does is shown below).
2. **Disable TeaCache Properly:** Disable the TeaCache extension when you don't need it. If the residual cache is not cleared properly, you may encounter unexpected behavior that requires a full relaunch.
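
For reference, clearing the residual cache simply resets the TeaCache state stored on the patched model class and releases GPU memory. This is a simplified excerpt mirroring `clear_residual_cache()` in `scripts/teacache.py`:

```python
import torch
from backend.nn.flux import IntegratedFluxTransformer2DModel

def clear_residual_cache():
    # Drop the cached tensors and counters used by TeaCache.
    IntegratedFluxTransformer2DModel.previous_residual = None
    IntegratedFluxTransformer2DModel.previous_modulated_input = None
    IntegratedFluxTransformer2DModel.cnt = 0
    IntegratedFluxTransformer2DModel.accumulated_rel_l1_distance = 0
    torch.cuda.empty_cache()  # free GPU memory held by the cached tensors
```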
Several AI assistants assisted with code generation and refinement for this extension, based on the resources below.
**Credits and Resources**

This adaptation leverages [TeaCache4FLUX](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4FLUX) and builds upon the foundational work of the original TeaCache repository: [https://github.com/ali-vilab/TeaCache](https://github.com/ali-vilab/TeaCache).

For additional information and other integrations, explore:

* [ComfyUI-TeaCache](https://github.com/welltop-cn/ComfyUI-TeaCache)

scripts/teacache.py Normal file

@@ -0,0 +1,245 @@
import torch
import numpy as np
import gradio as gr

from modules import scripts
from modules.ui_components import InputAccordion
from backend.nn.flux import IntegratedFluxTransformer2DModel, timestep_embedding
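
# TeaCache state (step counter, threshold, cached tensors) is stored as class
# attributes on IntegratedFluxTransformer2DModel, and its inner_forward /
# forward methods are monkey-patched below while the extension is enabled.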
class TeaCache(scripts.Script):
    def __init__(self):
        super().__init__()
        self.enable_teacache = False
        self.rel_l1_thresh = 0.4
        self.steps = 25
        self.last_input_shape = None  # Record the last input dimensions

    def title(self):
        return "TeaCache"

    def show(self, is_img2img):
        return scripts.AlwaysVisible
    def ui(self, is_img2img):
        with InputAccordion(False, label=self.title(), elem_id="extensions-teacache") as teacache_enabled:
            with gr.Row():
                enable_teacache_checkbox = gr.Checkbox(
                    label="Enable TeaCache",
                    value=self.enable_teacache,
                    tooltip="Enable TeaCache to speed up inference by caching intermediate results."
                )
            with gr.Row():
                rel_l1_thresh_slider = gr.Slider(
                    label="Relative L1 Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=self.rel_l1_thresh,
                    tooltip="Threshold for the accumulated relative L1 distance. Higher values cache more aggressively."
                )
            with gr.Row():
                steps_slider = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=self.steps,
                    tooltip="Number of sampling steps; should match the steps used for generation."
                )
            with gr.Row():
                clear_cache_button = gr.Button("Clear Residual Cache", variant="secondary")
            with gr.Row():
                gr.Markdown("**Note**: Clear the residual cache when changing image size or disabling TeaCache.")

            # Clear the residual cache when the button is clicked. The handler
            # returns nothing because no output component is bound.
            def clear_residual_cache_ui():
                self.clear_residual_cache()

            clear_cache_button.click(
                fn=clear_residual_cache_ui,
                outputs=None
            )

        # These components are passed to process() in this order.
        return [teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider]
    def clear_residual_cache(self):
        """Clear residual cache and free GPU memory."""
        if hasattr(IntegratedFluxTransformer2DModel, "previous_residual"):
            IntegratedFluxTransformer2DModel.previous_residual = None
        if hasattr(IntegratedFluxTransformer2DModel, "previous_modulated_input"):
            IntegratedFluxTransformer2DModel.previous_modulated_input = None
        if hasattr(IntegratedFluxTransformer2DModel, "cnt"):
            IntegratedFluxTransformer2DModel.cnt = 0
        if hasattr(IntegratedFluxTransformer2DModel, "accumulated_rel_l1_distance"):
            IntegratedFluxTransformer2DModel.accumulated_rel_l1_distance = 0

        # Free GPU memory held by the cached tensors
        torch.cuda.empty_cache()
        print("Residual cache cleared and GPU memory freed.")
    def process(self, p, teacache_enabled, enable_teacache_checkbox, rel_l1_thresh_slider, steps_slider):
        # Clear the residual cache whenever the input dimensions change.
        current_input_shape = (p.width, p.height)
        if self.last_input_shape is not None and current_input_shape != self.last_input_shape:
            self.clear_residual_cache()
        self.last_input_shape = current_input_shape

        # Debug information
        print("TeaCache enabled:", teacache_enabled)
        print("Enable TeaCache checkbox:", enable_teacache_checkbox)
        print("Relative L1 Threshold:", rel_l1_thresh_slider)
        print("Steps:", steps_slider)

        if teacache_enabled:
            # Record TeaCache parameters in the generation parameters.
            p.extra_generation_params.update({
                "enable_teacache": enable_teacache_checkbox,
                "rel_l1_thresh": rel_l1_thresh_slider,
                "steps": steps_slider,
            })

            # Dynamically set TeaCache state as class attributes of
            # IntegratedFluxTransformer2DModel.
            IntegratedFluxTransformer2DModel.enable_teacache = enable_teacache_checkbox
            IntegratedFluxTransformer2DModel.cnt = 0
            IntegratedFluxTransformer2DModel.rel_l1_thresh = rel_l1_thresh_slider
            IntegratedFluxTransformer2DModel.steps = steps_slider
            IntegratedFluxTransformer2DModel.accumulated_rel_l1_distance = 0
            IntegratedFluxTransformer2DModel.previous_modulated_input = None
            IntegratedFluxTransformer2DModel.previous_residual = None

            # Patch inner_forward, saving the original only once so that a
            # second run does not overwrite it with the patched version.
            if not hasattr(IntegratedFluxTransformer2DModel, "original_inner_forward"):
                IntegratedFluxTransformer2DModel.original_inner_forward = IntegratedFluxTransformer2DModel.inner_forward
            IntegratedFluxTransformer2DModel.inner_forward = patched_inner_forward

            # Patch forward the same way.
            if not hasattr(IntegratedFluxTransformer2DModel, "original_forward"):
                IntegratedFluxTransformer2DModel.original_forward = IntegratedFluxTransformer2DModel.forward
            IntegratedFluxTransformer2DModel.forward = patched_forward
        else:
            # If TeaCache is disabled, restore the original methods.
            if hasattr(IntegratedFluxTransformer2DModel, "original_inner_forward"):
                IntegratedFluxTransformer2DModel.inner_forward = IntegratedFluxTransformer2DModel.original_inner_forward
            if hasattr(IntegratedFluxTransformer2DModel, "original_forward"):
                IntegratedFluxTransformer2DModel.forward = IntegratedFluxTransformer2DModel.original_forward

            # Clear residual cache and reset TeaCache attributes to defaults.
            self.clear_residual_cache()
            IntegratedFluxTransformer2DModel.enable_teacache = False
            IntegratedFluxTransformer2DModel.cnt = 0
            IntegratedFluxTransformer2DModel.rel_l1_thresh = 0.4
            IntegratedFluxTransformer2DModel.steps = 25
            IntegratedFluxTransformer2DModel.accumulated_rel_l1_distance = 0
            IntegratedFluxTransformer2DModel.previous_modulated_input = None
            IntegratedFluxTransformer2DModel.previous_residual = None
            print("TeaCache fully disabled and cache cleared.")
def patched_inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
    # Fall back to the original method when TeaCache is disabled.
    if not getattr(self, "enable_teacache", False):
        return self.original_inner_forward(img, img_ids, txt, txt_ids, timesteps, y, guidance)

    # Print "TeaCache is enabled!" only once.
    if not hasattr(self, "_teacache_enabled_printed"):
        print("TeaCache is enabled!")
        self._teacache_enabled_printed = True

    # Parameters from the UI
    rel_l1_thresh = getattr(self, "rel_l1_thresh", 0.4)
    steps = getattr(self, "steps", 25)

    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # Image and text embedding
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))

    # If guidance_embed is enabled, add guidance information
    if self.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)

    # Merge text and image IDs for positional embedding
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)

    # TeaCache uses the embedded image as the modulated input
    modulated_inp = img.clone()

    # Clear the cache if the cached tensors no longer match the input shape
    if getattr(self, "previous_modulated_input", None) is not None:
        if self.previous_modulated_input.shape != modulated_inp.shape:
            self.previous_modulated_input = None
            self.previous_residual = None
            print("Cleared cache due to shape mismatch.")

    # Always compute the first and last steps; in between, accumulate the
    # rescaled relative L1 distance and skip computation while it stays
    # below the threshold.
    if self.cnt == 0 or self.cnt == steps - 1:
        should_calc = True
        self.accumulated_rel_l1_distance = 0
    elif getattr(self, "previous_modulated_input", None) is not None:
        coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
        rescale_func = np.poly1d(coefficients)
        self.accumulated_rel_l1_distance += rescale_func(
            ((modulated_inp - self.previous_modulated_input).abs().mean()
             / self.previous_modulated_input.abs().mean()).cpu().item()
        )
        if self.accumulated_rel_l1_distance < rel_l1_thresh:
            should_calc = False
        else:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
    else:
        should_calc = True

    self.previous_modulated_input = modulated_inp
    self.cnt += 1
    if self.cnt == steps:
        self.cnt = 0

    if not should_calc and getattr(self, "previous_residual", None) is not None:
        # Reuse the cached residual instead of running the transformer blocks.
        img += self.previous_residual
    else:
        ori_img = img.clone()
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]
        self.previous_residual = img - ori_img

    # Final output projection
    img = self.final_layer(img, vec)
    return img
def patched_forward(self, x, timestep, context, y, guidance=None, **kwargs):
    # Pick up TeaCache parameters if provided, popping them so they are not
    # forwarded to the original method.
    self.enable_teacache = kwargs.pop("enable_teacache", getattr(self, "enable_teacache", False))
    self.rel_l1_thresh = kwargs.pop("rel_l1_thresh", getattr(self, "rel_l1_thresh", 0.4))
    self.steps = kwargs.pop("steps", getattr(self, "steps", 25))

    # Call the original forward method.
    return self.original_forward(x, timestep, context, y, guidance, **kwargs)