apply_token_merging

This commit is contained in:
lllyasviel
2024-02-23 15:43:27 -08:00
parent 2a7fb1be24
commit bde779a526
3 changed files with 26 additions and 36 deletions

View File

@@ -33,6 +33,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
from modules.sd_models import apply_token_merging
# some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -747,13 +748,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
res = process_images_inner(p)
finally:
sd_models.apply_token_merging(p.sd_model, 0)
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -1259,6 +1256,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = self.rng.next()
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
apply_token_merging(self.sd_model, self.get_token_merging_ratio())
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=x,
@@ -1366,12 +1365,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
self.calculate_hr_conds()
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
if self.scripts is not None:
self.scripts.before_hr(self)
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=samples,
@@ -1385,8 +1384,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.sampler = None
devices.torch_gc()
@@ -1687,6 +1684,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
x *= self.initial_noise_multiplier
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
apply_token_merging(self.sd_model, self.get_token_merging_ratio())
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=self.init_latent,