Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-02-10 09:59:57 +00:00)
Commit: lora and kohyafix
@@ -1,7 +1,33 @@
 import gradio as gr
 
 from modules import scripts
-from ldm_patched.contrib.external_model_downscale import PatchModelAddDownscale
+from backend.misc.image_resize import adaptive_resize
+
+
+class PatchModelAddDownscale:
+    def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
+        sigma_start = model.model.predictor.percent_to_sigma(start_percent)
+        sigma_end = model.model.predictor.percent_to_sigma(end_percent)
+
+        def input_block_patch(h, transformer_options):
+            if transformer_options["block"][1] == block_number:
+                sigma = transformer_options["sigmas"][0].item()
+                if sigma <= sigma_start and sigma >= sigma_end:
+                    h = adaptive_resize(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
+            return h
+
+        def output_block_patch(h, hsp, transformer_options):
+            if h.shape[2] != hsp.shape[2]:
+                h = adaptive_resize(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
+            return h, hsp
+
+        m = model.clone()
+        if downscale_after_skip:
+            m.set_model_input_block_patch_after_skip(input_block_patch)
+        else:
+            m.set_model_input_block_patch(input_block_patch)
+        m.set_model_output_block_patch(output_block_patch)
+        return (m,)
 
 
 opPatchModelAddDownscale = PatchModelAddDownscale()
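For orientation: the new local PatchModelAddDownscale is the Kohya "Deep Shrink" style high-res fix. Inside the [start_percent, end_percent] sigma window it shrinks the hidden states entering the chosen UNet input block by downscale_factor, and output_block_patch scales them back up so they match the skip connection again. A minimal sketch of how the patch might be wired up, assuming the script applies it per sampling run through p.sd_model.forge_objects.unet (the pattern Forge's builtin integrated scripts use); the enabled flag and every numeric or method value below are illustrative placeholders, not values from this commit:

# Illustrative sketch: would live inside the script's process_before_every_sampling(p, ...).
# All argument values are placeholders.
if enabled:
    unet = p.sd_model.forge_objects.unet
    # patch() works on a clone of the ModelPatcher and returns it in a 1-tuple,
    # so the shared checkpoint itself is left untouched.
    unet = opPatchModelAddDownscale.patch(
        unet,
        block_number=3, downscale_factor=2.0,
        start_percent=0.0, end_percent=0.35,
        downscale_after_skip=True,
        downscale_method='bicubic', upscale_method='bicubic')[0]
    p.sd_model.forge_objects.unet = unet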
@@ -11,8 +11,49 @@ import torch
 from typing import Union
 
 from modules import shared, sd_models, errors, scripts
-from ldm_patched.modules.utils import load_torch_file
-from ldm_patched.modules.sd import load_lora_for_models
+from backend.utils import load_torch_file
+from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora
+
+
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
+    model_flag = type(model.model).__name__ if model is not None else 'default'
+
+    unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
+    clip_keys = model_lora_keys_clip(clip.cond_stage_model) if clip is not None else {}
+
+    lora_unmatch = lora
+    lora_unet, lora_unmatch = load_lora(lora_unmatch, unet_keys)
+    lora_clip, lora_unmatch = load_lora(lora_unmatch, clip_keys)
+
+    if len(lora_unmatch) > 12:
+        print(f'[LORA] LoRA version mismatch for {model_flag}: {filename}')
+        return model, clip
+
+    if len(lora_unmatch) > 0:
+        print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
+
+    new_model = model.clone() if model is not None else None
+    new_clip = clip.clone() if clip is not None else None
+
+    if new_model is not None and len(lora_unet) > 0:
+        loaded_keys = new_model.add_patches(lora_unet, strength_model)
+        skipped_keys = [item for item in lora_unet if item not in loaded_keys]
+        if len(skipped_keys) > 12:
+            print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
+        else:
+            print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
+        model = new_model
+
+    if new_clip is not None and len(lora_clip) > 0:
+        loaded_keys = new_clip.add_patches(lora_clip, strength_clip)
+        skipped_keys = [item for item in lora_clip if item not in loaded_keys]
+        if len(skipped_keys) > 12:
+            print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
+        else:
+            print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
+        clip = new_clip
+
+    return model, clip
 
 
 @functools.lru_cache(maxsize=5)
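In short, the reworked loader splits the incoming LoRA state dict into UNet and CLIP key groups with load_lora, treats up to 12 leftover keys as tolerable noise, refuses the file as a version mismatch beyond that, and applies each matched group to a clone via add_patches so the original model and clip objects stay untouched. A minimal calling sketch, assuming backend.utils.load_torch_file reads a .safetensors state dict the same way the ldm_patched helper it replaces did; the path and strength values are placeholders, not part of the commit:

# Illustrative sketch: drive the loader above for one LoRA file.
lora_sd = load_torch_file('/path/to/some_lora.safetensors')    # placeholder path
model, clip = load_lora_for_models(
    model, clip, lora_sd,
    strength_model=0.8, strength_clip=0.8,                      # placeholder strengths
    filename='some_lora.safetensors')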