fix RecursionError when enabling and disabling TeaCache

likelovewant
2025-01-13 14:39:16 +08:00
parent 14ca779359
commit e2ba73750c
2 changed files with 5 additions and 5 deletions

View File

@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 ## Sd-Forge-TeaCache: Speed up Your Diffusion Models
 
 **Introduction**
@@ -57,7 +56,3 @@ For additional information and other integrations, explore:
 * [ComfyUI-TeaCache](https://github.com/welltop-cn/ComfyUI-TeaCache)
-=======
-# sd-forge-teacache
-teacache adaption on forge webui
->>>>>>> 45804d1411ad9da74bda8e26d6c7a6d2068c579c

View File

@@ -86,6 +86,8 @@ class TeaCache(scripts.Script):
             setattr(IntegratedFluxTransformer2DModel, "cnt", 0)
         if hasattr(IntegratedFluxTransformer2DModel, "accumulated_rel_l1_distance"):
             setattr(IntegratedFluxTransformer2DModel, "accumulated_rel_l1_distance", 0)
+        if hasattr(IntegratedFluxTransformer2DModel, "_teacache_enabled_printed"):
+            delattr(IntegratedFluxTransformer2DModel, "_teacache_enabled_printed")
         # Free GPU memory
         torch.cuda.empty_cache()
         print("Residual cache cleared and GPU memory freed.")
@@ -151,6 +153,9 @@ class TeaCache(scripts.Script):
         # Restore the original forward method
         if hasattr(IntegratedFluxTransformer2DModel, "original_forward"):
             IntegratedFluxTransformer2DModel.forward = IntegratedFluxTransformer2DModel.original_forward
+        # Remove the patched forward method to avoid recursion
+        if hasattr(IntegratedFluxTransformer2DModel, "patched_forward"):
+            delattr(IntegratedFluxTransformer2DModel, "patched_forward")
 
 def patched_inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
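
Why removing `patched_forward` matters: the extension monkey-patches `IntegratedFluxTransformer2DModel.forward`, stashing the original in `original_forward`. If the patched function survives a disable, a later enable can capture it as the "original", and the wrapper then calls itself until Python raises RecursionError. A minimal sketch of the failure mode and the fix, using a stand-in class and hypothetical `enable_patch`/`disable_patch` helpers (not the extension's actual code):

```python
class Model:  # stand-in for IntegratedFluxTransformer2DModel
    def forward(self, x):
        return x + 1

def enable_patch(cls):
    # Capture the original only once: if cls.forward were already the patched
    # function (stale state from a previous enable), capturing it here would
    # make the wrapper call itself -> RecursionError.
    if not hasattr(cls, "original_forward"):
        cls.original_forward = cls.forward

    def patched_forward(self, x):
        # ... TeaCache skip/reuse logic would run here ...
        return cls.original_forward(self, x)

    cls.patched_forward = patched_forward
    cls.forward = patched_forward

def disable_patch(cls):
    # Restore the original forward method.
    if hasattr(cls, "original_forward"):
        cls.forward = cls.original_forward
    # The commit's fix: also delete the stale patched function so nothing
    # left on the class can be mistaken for live patch state later.
    if hasattr(cls, "patched_forward"):
        delattr(cls, "patched_forward")

m = Model()
enable_patch(Model)
assert m.forward(1) == 2
disable_patch(Model)
assert m.forward(1) == 2
enable_patch(Model)
assert m.forward(1) == 2  # toggling repeatedly no longer recurses
```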