apply_token_merging

This commit is contained in:
lllyasviel
2024-02-23 15:43:27 -08:00
parent 2a7fb1be24
commit bde779a526
3 changed files with 26 additions and 36 deletions

View File

@@ -633,8 +633,16 @@ def unload_model_weights(sd_model=None, info=None):
def apply_token_merging(sd_model, token_merging_ratio):
if token_merging_ratio > 0:
print('Token merging is under construction now and the setting will not take effect.')
if token_merging_ratio <= 0:
return
print(f'token_merging_ratio = {token_merging_ratio}')
from ldm_patched.contrib.external_tomesd import TomePatcher
sd_model.forge_objects.unet = TomePatcher().patch(
model=sd_model.forge_objects.unet,
ratio=token_merging_ratio
)
# TODO: rework using new UNet patcher system
return