Update sd_hijack.py

This commit is contained in:
lllyasviel
2024-01-25 05:05:23 -08:00
parent 231b860e92
commit 7273d9b89f

View File

@@ -136,25 +136,10 @@ class StableDiffusionModelHijack:
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
try:
self.optimization_method = apply_optimizations(option)
except Exception as e:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
pass
def convert_sdxl_to_ssd(self, m):
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
delattr(m.model.diffusion_model.middle_block, '1')
delattr(m.model.diffusion_model.middle_block, '2')
for i in ['9', '8', '7', '6', '5', '4']:
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
devices.torch_gc()
pass
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
@@ -199,8 +184,6 @@ class StableDiffusionModelHijack:
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
self.apply_optimizations()
self.clip = m.cond_stage_model
def flatten(el):
@@ -223,11 +206,9 @@ class StableDiffusionModelHijack:
else:
sd_unet.original_forward = None
def undo_hijack(self, m):
pass
def apply_circular(self, enable):
if self.circular_enabled == enable:
return