VAE patcher and more types of unet patches

This commit is contained in:
lllyasviel
2024-02-29 22:37:34 -08:00
committed by GitHub
parent b59deaa382
commit ef35383b4a
6 changed files with 81 additions and 10 deletions

View File

@@ -75,6 +75,12 @@ class ModelPatcher:
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
def set_model_vae_encode_wrapper(self, wrapper_function):
self.model_options["model_vae_encode_wrapper"] = wrapper_function
def set_model_vae_decode_wrapper(self, wrapper_function):
self.model_options["model_vae_decode_wrapper"] = wrapper_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
@@ -242,7 +248,17 @@ class ModelPatcher:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
if w1.ndim == weight.ndim == 4:
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
print(f'Merged with {key} channel changed to {new_shape}')
new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
new_weight = torch.zeros(size=new_shape).to(weight)
new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
new_weight = new_weight.contiguous().clone()
weight = new_weight
else:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora": #lora/locon

View File

@@ -163,7 +163,10 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None):
def __init__(self, sd=None, device=None, config=None, dtype=None, no_init=False):
if no_init:
return
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
@@ -215,6 +218,19 @@ class VAE:
self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def clone(self):
n = VAE(no_init=True)
n.patcher = self.patcher.clone()
n.memory_used_encode = self.memory_used_encode
n.memory_used_decode = self.memory_used_decode
n.downscale_ratio = self.downscale_ratio
n.latent_channels = self.latent_channels
n.first_stage_model = self.first_stage_model
n.device = self.device
n.vae_dtype = self.vae_dtype
n.output_device = self.output_device
return n
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@@ -242,7 +258,7 @@ class VAE:
samples /= 3.0
return samples
def decode(self, samples_in):
def decode_inner(self, samples_in):
if model_management.VAE_ALWAYS_TILED:
return self.decode_tiled(samples_in).to(self.output_device)
@@ -264,12 +280,19 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode(self, samples_in):
wrapper = self.patcher.model_options.get('model_vae_decode_wrapper', None)
if wrapper is None:
return self.decode_inner(samples_in)
else:
return wrapper(self.decode_inner, samples_in)
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1,-1)
def encode(self, pixel_samples):
def encode_inner(self, pixel_samples):
if model_management.VAE_ALWAYS_TILED:
return self.encode_tiled(pixel_samples)
@@ -291,6 +314,13 @@ class VAE:
return samples
def encode(self, pixel_samples):
wrapper = self.patcher.model_options.get('model_vae_encode_wrapper', None)
if wrapper is None:
return self.encode_inner(pixel_samples)
else:
return wrapper(self.encode_inner, pixel_samples)
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)